// Copyright 2016 The Gemmlowp Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_ #define GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_ #include "multi_thread_common.h" #include "single_thread_transform.h" namespace gemmlowp { namespace meta { namespace internal { const int kTransformTaskOverhead = 128000; const int kMinTransformTaskSize = 32000; template inline bool PrepareTransform1DTasks(MultiThreadingContext* context, const Params& params, int kernel_size, std::vector* task_params) { typedef Transform1DUtil Util; const int max_threads = ResolveMaxThreads(context->max_num_threads()); const int task_size = Util::EstimateComputeCost(params.kernel); const int max_tasks_by_size = (task_size - kTransformTaskOverhead) / kMinTransformTaskSize; const int real_tasks = std::max(1, std::min(max_threads, max_tasks_by_size)); if (real_tasks == 1) { return false; } const int chunk = params.kernel.count / real_tasks; for (int i = 0; i < real_tasks - 1; ++i) { task_params->push_back(params); Params& task = task_params->back(); task.kernel.count = chunk; task.input = Util::OffsetInput(params.kernel, params.input, i * chunk); task.output = Util::OffsetOutput(params.kernel, params.output, i * chunk); } task_params->push_back(params); Params& task = task_params->back(); const int sum_chunk = (real_tasks - 1) * chunk; task.kernel.count = params.kernel.count - sum_chunk; task.input = Util::OffsetInput(params.kernel, params.input, sum_chunk); task.output = Util::OffsetOutput(params.kernel, params.output, sum_chunk); return true; } template struct Transform1DTaskRunner : gemmlowp::Task { Transform1DTaskRunner(const Params& params) : params(params) {} void Run() override { Transform1D(params); } Params params; }; } // namespace internal template inline void MultiThreadTransform1D(MultiThreadingContext* context, const Params& params) { typedef internal::Transform1DTaskRunner TaskRunnerType; std::vector task_params; if (!internal::PrepareTransform1DTasks( context, params, kernel_size, &task_params)) { Transform1D(params); return; } auto workers_pool = context->workers_pool(); std::vector tasks; std::for_each(task_params.begin(), task_params.end(), [tasks](Params* param) { tasks.push_back(new TaskRunnerType(param)); }); workers_pool->Execute(tasks); } } // namespace meta } // namespace gemmlowp #endif // GEMMLOWP_META_MULTI_THREAD_TRANSFORM_H_