// Copyright 2017 The Gemmlowp Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // dispatch_gemm_shape.h: dispatch GEMM calls according to their shape #ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ #define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ #include "../internal/kernel_default.h" #include "../public/map.h" #include "../public/output_stages.h" #include "multi_thread_gemm.h" namespace gemmlowp { template struct TransposeImpl { typedef T DstType; static T Run(const T& t) { return t; } }; template using TransposeType = typename TransposeImpl::DstType; template TransposeType Transpose(const T& t) { return TransposeImpl::Run(t); } template struct TransposeMapOrder { static constexpr MapOrder Value = Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor; }; template struct TransposeVectorShape { static constexpr VectorShape Value = Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row; }; template struct TransposeImpl> { typedef VectorMap SrcType; static constexpr VectorShape TransposedShape = TransposeVectorShape::Value; typedef VectorMap DstType; static DstType Run(const SrcType& src) { return DstType(src.data(), src.size()); } }; template struct TransposeImpl> { typedef MatrixMap SrcType; static constexpr MapOrder TransposedOrder = TransposeMapOrder::Value; typedef MatrixMap DstType; static DstType Run(const SrcType& src) { return DstType(src.data(), src.cols(), src.rows(), src.stride()); } }; template struct TransposeImpl> { typedef OutputStageQuantizeDownInt32ToUint8ScalePC SrcType; static const VectorShape TransposedShape = TransposeVectorShape::Value; typedef OutputStageQuantizeDownInt32ToUint8ScalePC DstType; static DstType Run(const SrcType& src) { DstType dst; dst.result_shift = src.result_shift; dst.result_offset = Transpose(src.result_offset); dst.result_mult_int = Transpose(src.result_mult_int); return dst; } }; template struct TransposeImpl> { typedef OutputStageScaleInt32ByFixedPointAndExponentPC SrcType; static const VectorShape TransposedShape = TransposeVectorShape::Value; typedef OutputStageScaleInt32ByFixedPointAndExponentPC DstType; static DstType Run(const SrcType& src) { DstType dst; dst.result_fixedpoint_multiplier = Transpose(src.result_fixedpoint_multiplier); dst.result_exponent = Transpose(src.result_exponent); dst.result_offset_after_shift = src.result_offset_after_shift; return dst; } }; template struct TransposeImpl> { typedef OutputStageBiasAddition SrcType; typedef TransposeType TransposedVectorMapType; typedef OutputStageBiasAddition DstType; static DstType Run(const SrcType& src) { DstType dst; dst.bias_vector = Transpose(src.bias_vector); return dst; } }; // TODO(benoitjacob) - does anyone understand C++ variadic templates? // How to use them to implement TransposeTuple? Note: there are lots // of answers on StackOverflow but they seem to all involve either // C++14/C++17 (we can only use C++11) or lots of abstract nonsense. inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; } template std::tuple> TransposeTuple(const std::tuple& t) { return std::make_tuple(Transpose(std::get<0>(t))); } template std::tuple, TransposeType> TransposeTuple( const std::tuple& t) { return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t))); } template std::tuple, TransposeType, TransposeType> TransposeTuple(const std::tuple& t) { return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), Transpose(std::get<2>(t))); } template std::tuple, TransposeType, TransposeType, TransposeType> TransposeTuple(const std::tuple& t) { return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), Transpose(std::get<2>(t)), Transpose(std::get<3>(t))); } template std::tuple, TransposeType, TransposeType, TransposeType, TransposeType> TransposeTuple(const std::tuple& t) { return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), Transpose(std::get<4>(t))); } template std::tuple, TransposeType, TransposeType, TransposeType, TransposeType, TransposeType> TransposeTuple(const std::tuple& t) { return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), Transpose(std::get<4>(t)), Transpose(std::get<5>(t))); } template void DispatchGemmShape(GemmContextType* context, const MatrixMap& lhs, const MatrixMap& rhs, MatrixMap* result, const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, const OutputPipelineType& output_pipeline) { assert(lhs.cols() == rhs.rows()); int rows = result->rows(); int cols = result->cols(); int depth = lhs.cols(); if (rows == 0 || cols == 0 || depth == 0) { // Vacuous GEMM, return early to avoid having to deal with // zero sizes below. return; } if (rows < cols) { auto transposed_result_map = Transpose(*result); return DispatchGemmShape( context, Transpose(rhs), Transpose(lhs), &transposed_result_map, Transpose(rhs_offset), Transpose(lhs_offset), TransposeTuple(output_pipeline)); } typedef DefaultKernel Kernel; MultiThreadGemm(context, Kernel(), lhs, rhs, result, lhs_offset, rhs_offset, output_pipeline); } } // end namespace gemmlowp #endif // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_