diff options
author | Miao Wang <miaowang@google.com> | 2015-10-19 15:30:10 -0700 |
---|---|---|
committer | Miao Wang <miaowang@google.com> | 2016-02-03 11:19:49 -0800 |
commit | 7b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1 (patch) | |
tree | 5ff593d7c6f5798eaeb8708482c3af11b7021183 /eight_bit_int_gemm | |
parent | 963b3cb31bd43460e5879d9f70e2f0636183634e (diff) | |
download | gemmlowp-7b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1.tar.gz |
Rebase gemmlowp to e96c3a9 android-n-preview-1
- Better multi-thread perf
- API change to match standard GEMM: C=A*B rather than C=B*A
Change-Id: I74159fcb246d2a1fc246015e221306bbe11ea8e3
Diffstat (limited to 'eight_bit_int_gemm')
-rw-r--r-- | eight_bit_int_gemm/Android.mk | 2 | ||||
-rw-r--r-- | eight_bit_int_gemm/eight_bit_int_gemm.cc | 305 | ||||
-rw-r--r-- | eight_bit_int_gemm/eight_bit_int_gemm.h | 7 |
3 files changed, 297 insertions, 17 deletions
diff --git a/eight_bit_int_gemm/Android.mk b/eight_bit_int_gemm/Android.mk index 663202c..4490262 100644 --- a/eight_bit_int_gemm/Android.mk +++ b/eight_bit_int_gemm/Android.mk @@ -40,6 +40,6 @@ endif LOCAL_CFLAGS += -std=c++11 LOCAL_CFLAGS += -DGEMMLOWP_USE_STLPORT LOCAL_C_INCLUDES += external/gemmlowp/ -LOCAL_NDK_STL_VARIANT := stlport_static +LOCAL_NDK_STL_VARIANT := c++_static include $(BUILD_STATIC_LIBRARY) diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.cc b/eight_bit_int_gemm/eight_bit_int_gemm.cc index 06202b5..ecea180 100644 --- a/eight_bit_int_gemm/eight_bit_int_gemm.cc +++ b/eight_bit_int_gemm/eight_bit_int_gemm.cc @@ -14,16 +14,28 @@ #include "eight_bit_int_gemm.h" +#include <memory> + // gemmlowp symbols should have hidden visibility. // currently this is ensured in the build system by // passing -finlines-visibility-hidden. TODO: it would be // safer to hardcode it here with some #pragma's. #include "../public/gemmlowp.h" -namespace gemmlowp { +// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON +// code. This code path consists of a number of meta-programmed, automatically +// generated GEMM kernels that are suitable for some sizes of input matrices. +// Due to the fact that the generated code relies heavily on loop unrolling, +// inling and currying of runtime parameters the size of the generated binary +// is quite significant (approx. 200kb) which might be prohibitive in +// low-memory situations. -namespace eight_bit_int_gemm { +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) +#include "../meta/multi_thread_gemm.h" +#endif +namespace gemmlowp { +namespace eight_bit_int_gemm { namespace { // To be used as template parameter for GlobalLock. 
@@ -54,38 +66,224 @@ void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k, std::uint8_t* c, std::int32_t c_offset, std::int32_t c_mult_int, std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) { - const int lhs_offset = b_offset; - const int rhs_offset = a_offset; + const int lhs_offset = a_offset; + const int rhs_offset = b_offset; const int result_offset = c_offset; const int result_mult_int = c_mult_int; const int result_shift = c_shift; static const MapOrder ResultOrder = - transpose_c ? MapOrder::ColMajor : MapOrder::RowMajor; + transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor; static const MapOrder LhsOrder = - transpose_b == transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor; + transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor; static const MapOrder RhsOrder = - transpose_a == transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor; + transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor; - MatrixMap<const std::uint8_t, LhsOrder> lhs(b, n, k, ldb); - MatrixMap<const std::uint8_t, RhsOrder> rhs(a, k, m, lda); - MatrixMap<std::uint8_t, ResultOrder> result(c, n, m, ldc); + MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda); + MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb); + MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc); switch (bit_depth) { -#define GEMMLOWP_HANDLE_BIT_DEPTH(AnBn, LnRn) \ - case BitDepthSetting::AnBn: \ - Gemm<std::uint8_t, gemmlowp::BitDepthSetting::LnRn>( \ +#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \ + case BitDepthSetting::BIT_DEPTH_SETTING: \ + Gemm<std::uint8_t, BIT_DEPTH_PARAMS>( \ context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \ result_mult_int, result_shift); \ return; - GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, L8R8) - GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, L7R5) + GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams) + GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams) default: abort(); #undef GEMMLOWP_HANDLE_BIT_DEPTH } } 
+template <bool transpose_a, bool transpose_b, bool transpose_c> +void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k, + const std::uint8_t* a, std::int32_t a_offset, + int lda, const std::uint8_t* b, + std::int32_t b_offset, int ldb, std::int32_t* c, + int ldc, BitDepthSetting bit_depth) { + const int lhs_offset = a_offset; + const int rhs_offset = b_offset; + + static const MapOrder ResultOrder = + transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor; + static const MapOrder LhsOrder = + transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor; + static const MapOrder RhsOrder = + transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor; + + MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda); + MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb); + MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc); + + auto empty_pipeline = std::make_tuple(); + + switch (bit_depth) { +#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \ + case BitDepthSetting::BIT_DEPTH_SETTING: \ + GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>( \ + context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \ + return; + GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams) + GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams) + default: + abort(); +#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32 + } +} + +class Scratch { + public: + Scratch() : buffer_(), size_(0) {} + + void AssureSize(std::int32_t required_size) { + if (size_ >= required_size) { + return; + } + buffer_.reset(new std::uint8_t[required_size]); + size_ = required_size; + } + + void Clear() { + buffer_.reset(nullptr); + size_ = 0; + } + + std::uint8_t* buffer() { return buffer_.get(); } + + private: + std::unique_ptr<std::uint8_t[]> buffer_; + std::int32_t size_; +}; + +Scratch* global_scratch = nullptr; + +Scratch* GetOrCreateGlobalScratch() { + if (global_scratch == nullptr) { + global_scratch = new Scratch(); + } + 
return global_scratch; +} + +void DestroyGlobalScratch() { + delete global_scratch; + global_scratch = nullptr; +} + +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) + +bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) { + // Is it row major and nicely packed? + if (transpose && stride == cols) { + return true; + } + + // Is it a one row vector? (a vector is both row and column major) + if (rows == 1) { + return true; + } + + return false; +} + +bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) { + // Is it column major and nicely packed? + if (!transpose && stride == rows) { + return true; + } + + // Is it a one column vector? (a vector is both row and column major) + if (cols == 1) { + return true; + } + + return false; +} + +bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c, + int m, int n, int k, int lda, int ldb, int ldc, + BitDepthSetting depth_setting) { + // Meta fastpath only supports 8bit x 8bit and k up to 2048. + if (depth_setting != BitDepthSetting::A8B8 || k > 2048) { + return false; + } + + // The first operand needs to be a row major matrix or a vector. + if (!IsRowMajorOrVector(transpose_a, lda, m, k)) { + return false; + } + + // The second operand needs to be a column major matrix or a vector. + if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) { + return false; + } + + // The result can either be a row major matrix, a column major matrix or + // a vector. + if (IsRowMajorOrVector(transpose_c, ldc, m, n)) { + return true; + } + + if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) { + return true; + } + + return false; +} + +// Assure enough scratch memory is allocated and run the fast path gemm. 
+void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs, + const std::uint8_t* rhs, int m, int n, int k, + std::int32_t lhs_offset, std::int32_t rhs_offset, + std::int32_t sum_offset, + std::int32_t multiplicative_offset, + std::int32_t shift, bool result_transpose, + std::int32_t result_stride, std::uint8_t* result) { + Scratch* scratch = GetOrCreateGlobalScratch(); + if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { + scratch->AssureSize( + meta::gemm_q8_scratch(m, n, k, context->max_num_threads())); + meta::multi_thread_gemm_q8( + context->workers_pool(), context->max_num_threads(), scratch->buffer(), + lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset, + multiplicative_offset, shift, result); + } else { + scratch->AssureSize( + meta::gemm_q8_scratch(n, m, k, context->max_num_threads())); + meta::multi_thread_gemm_q8( + context->workers_pool(), context->max_num_threads(), scratch->buffer(), + rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset, + multiplicative_offset, shift, result); + } +} + +// Assure enough scratch memory is allocated and run the 8bit to float fast +// path gemm. 
+void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs, + const std::uint8_t* rhs, int m, int n, int k, + std::int32_t lhs_offset, std::int32_t rhs_offset, + float result_offset, bool result_transpose, + std::int32_t result_stride, float* result) { + Scratch* scratch = GetOrCreateGlobalScratch(); + if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { + scratch->AssureSize( + meta::gemm_f_scratch(m, n, k, context->max_num_threads())); + meta::multi_thread_gemm_f( + context->workers_pool(), context->max_num_threads(), scratch->buffer(), + lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result); + } else { + scratch->AssureSize( + meta::gemm_f_scratch(n, m, k, context->max_num_threads())); + meta::multi_thread_gemm_f( + context->workers_pool(), context->max_num_threads(), scratch->buffer(), + rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result); + } +} + +#endif + } // end anonymous namespace // Public interface entry points @@ -99,6 +297,15 @@ void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, AutoGlobalLock<EightBitIntGemmLockId> lock; GemmContext* context = GetOrCreateGlobalContext(); +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) + if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, + ldb, ldc, bit_depth)) { + MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset, + c_mult_int, c_shift, transpose_c, ldc, c); + return; + } +#endif + #define GEMMLOWP_HANDLE_CASE(ta, tb, tc) \ if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \ EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b, \ @@ -118,6 +325,72 @@ void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, #undef GEMMLOWP_HANDLE_CASE } +void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, + int m, int n, int k, const std::uint8_t* a, + std::int32_t a_offset, std::int32_t lda, + const std::uint8_t* b, 
std::int32_t b_offset, + std::int32_t ldb, float* c, float c_offset, + std::int32_t ldc, BitDepthSetting bit_depth) { + AutoGlobalLock<EightBitIntGemmLockId> lock; + GemmContext* context = GetOrCreateGlobalContext(); + +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) + if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, + ldb, ldc, bit_depth)) { + MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset, + transpose_c, ldc, c); + return; + } +#endif + + // TODO(maciekc): implement a float output stage, get rid of scratch memory. + Scratch* scratch = GetOrCreateGlobalScratch(); + if (transpose_c) { + scratch->AssureSize(m * ldc * sizeof(std::int32_t)); + } else { + scratch->AssureSize(n * ldc * sizeof(std::int32_t)); + } + std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer()); + +#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc) \ + if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \ + EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \ + b, b_offset, ldb, temp_c, ldc, \ + bit_depth); \ + } + + GEMMLOWP_HANDLE_INT32_CASE(false, false, false) + GEMMLOWP_HANDLE_INT32_CASE(false, false, true) + GEMMLOWP_HANDLE_INT32_CASE(false, true, false) + GEMMLOWP_HANDLE_INT32_CASE(false, true, true) + GEMMLOWP_HANDLE_INT32_CASE(true, false, false) + GEMMLOWP_HANDLE_INT32_CASE(true, false, true) + GEMMLOWP_HANDLE_INT32_CASE(true, true, false) + GEMMLOWP_HANDLE_INT32_CASE(true, true, true) + +#undef GEMMLOWP_HANDLE_INT32_CASE + + if (transpose_c) { + // Row major. + for (int i = 0; i < m; ++i) { + float* dest_row = c + i * ldc; + std::int32_t* src_row = temp_c + i * ldc; + for (int j = 0; j < n; ++j) { + dest_row[j] = static_cast<float>(src_row[j]) * c_offset; + } + } + } else { + // Column major. 
+ for (int i = 0; i < n; ++i) { + float* dest_column = c + i * ldc; + std::int32_t* src_column = temp_c + i * ldc; + for (int j = 0; j < m; ++j) { + dest_column[j] = static_cast<float>(src_column[j]) * c_offset; + } + } + } +} + void SetMaxNumThreads(int n) { AutoGlobalLock<EightBitIntGemmLockId> lock; GemmContext* context = GetOrCreateGlobalContext(); @@ -127,8 +400,8 @@ void SetMaxNumThreads(int n) { void FreePersistentResources() { AutoGlobalLock<EightBitIntGemmLockId> lock; DestroyGlobalContext(); + DestroyGlobalScratch(); } } // namespace eight_bit_int_gemm - } // namespace gemmlowp diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.h b/eight_bit_int_gemm/eight_bit_int_gemm.h index 875a1da..6bd9dfe 100644 --- a/eight_bit_int_gemm/eight_bit_int_gemm.h +++ b/eight_bit_int_gemm/eight_bit_int_gemm.h @@ -25,6 +25,7 @@ namespace std { using ::uint8_t; using ::int32_t; using ::int64_t; +using ::uint64_t; } #endif @@ -60,6 +61,12 @@ void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, std::int32_t c_offset, std::int32_t c_mult_int, std::int32_t c_shift, int ldc, BitDepthSetting bit_depth); +void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, + int m, int n, int k, const std::uint8_t *a, + std::int32_t a_offset, int lda, const std::uint8_t *b, + std::int32_t b_offset, int ldb, float *c, float c_offset, + int ldc, BitDepthSetting bit_depth); + // Frees any persistent resources // (threads, thread pools, allocators, buffers, ...) // that gemmlowp might hold. This is called automatically |