diff options
author    : Miao Wang <miaowang@google.com>  (2017-07-06 11:06:31 -0700)
committer : Miao Wang <miaowang@google.com>  (2017-07-06 20:20:07 +0000)
commit    : a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1 (patch)
tree      : c6c4113609be686ad91ff6710922d8416ca3be1c /eight_bit_int_gemm
parent    : b76ecb153a784e70e07bf7e19bcf3dfd6caec815 (diff)
download  : gemmlowp-a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1.tar.gz
Rebase gemmlowp to 36ffd29
Test: mm
Test: build system image for sailfish
Test: BLAS CTS tests pass
Change-Id: I4cc9dbfd586f6653fc2d04e8e7ad78ada5d7dbe9
Diffstat (limited to 'eight_bit_int_gemm')
-rw-r--r--  eight_bit_int_gemm/eight_bit_int_gemm.cc  (77 lines changed)
-rw-r--r--  eight_bit_int_gemm/eight_bit_int_gemm.h   (4 lines changed)
2 files changed, 46 insertions, 35 deletions
diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.cc b/eight_bit_int_gemm/eight_bit_int_gemm.cc index ecea180..8113bf3 100644 --- a/eight_bit_int_gemm/eight_bit_int_gemm.cc +++ b/eight_bit_int_gemm/eight_bit_int_gemm.cc @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#endif #include "eight_bit_int_gemm.h" #include <memory> @@ -30,8 +33,14 @@ // is quite significant (approx. 200kb) which might be prohibitive in // low-memory situations. -#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) -#include "../meta/multi_thread_gemm.h" +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) +#include "../meta/legacy_multi_thread_gemm.h" +#else + +#if defined(GEMMLOWP_USE_META_FASTPATH) +#warning "META fast path turned on without NEON!" +#endif + #endif namespace gemmlowp { @@ -136,25 +145,31 @@ void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k, class Scratch { public: - Scratch() : buffer_(), size_(0) {} + Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {} void AssureSize(std::int32_t required_size) { if (size_ >= required_size) { return; } - buffer_.reset(new std::uint8_t[required_size]); + buffer_.reset(new std::uint8_t[required_size + 32]); + buffer_32_ = + buffer_.get() + + ((32 - (reinterpret_cast<uintptr_t>(buffer_.get()) % 32)) % 32); + assert((reinterpret_cast<uintptr_t>(buffer_32_) % 32) == 0); size_ = required_size; } void Clear() { buffer_.reset(nullptr); + buffer_32_ = nullptr; size_ = 0; } - std::uint8_t* buffer() { return buffer_.get(); } + std::uint8_t* buffer() { return buffer_32_; } private: std::unique_ptr<std::uint8_t[]> buffer_; + std::uint8_t* buffer_32_; std::int32_t size_; }; @@ -172,7 +187,7 @@ void DestroyGlobalScratch() { global_scratch = nullptr; } -#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) +#if 
defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) { // Is it row major and nicely packed? @@ -205,8 +220,8 @@ bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) { bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c, int m, int n, int k, int lda, int ldb, int ldc, BitDepthSetting depth_setting) { - // Meta fastpath only supports 8bit x 8bit and k up to 2048. - if (depth_setting != BitDepthSetting::A8B8 || k > 2048) { + // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048. + if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) { return false; } @@ -242,20 +257,19 @@ void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs, std::int32_t shift, bool result_transpose, std::int32_t result_stride, std::uint8_t* result) { Scratch* scratch = GetOrCreateGlobalScratch(); + const std::int32_t max_num_threads = context->max_num_threads(); if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { - scratch->AssureSize( - meta::gemm_q8_scratch(m, n, k, context->max_num_threads())); - meta::multi_thread_gemm_q8( - context->workers_pool(), context->max_num_threads(), scratch->buffer(), - lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset, - multiplicative_offset, shift, result); + scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads)); + meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads, + scratch->buffer(), lhs, rhs, m, n, k, lhs_offset, + rhs_offset, sum_offset, multiplicative_offset, + shift, result); } else { - scratch->AssureSize( - meta::gemm_q8_scratch(n, m, k, context->max_num_threads())); - meta::multi_thread_gemm_q8( - context->workers_pool(), context->max_num_threads(), scratch->buffer(), - rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset, - multiplicative_offset, shift, result); + scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, 
max_num_threads)); + meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads, + scratch->buffer(), rhs, lhs, n, m, k, rhs_offset, + lhs_offset, sum_offset, multiplicative_offset, + shift, result); } } @@ -267,18 +281,17 @@ void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs, float result_offset, bool result_transpose, std::int32_t result_stride, float* result) { Scratch* scratch = GetOrCreateGlobalScratch(); + const std::int32_t max_num_threads = context->max_num_threads(); if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { - scratch->AssureSize( - meta::gemm_f_scratch(m, n, k, context->max_num_threads())); - meta::multi_thread_gemm_f( - context->workers_pool(), context->max_num_threads(), scratch->buffer(), - lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result); + scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads)); + meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads, + scratch->buffer(), lhs, rhs, m, n, k, lhs_offset, + rhs_offset, result_offset, result); } else { - scratch->AssureSize( - meta::gemm_f_scratch(n, m, k, context->max_num_threads())); - meta::multi_thread_gemm_f( - context->workers_pool(), context->max_num_threads(), scratch->buffer(), - rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result); + scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads)); + meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads, + scratch->buffer(), rhs, lhs, n, m, k, rhs_offset, + lhs_offset, result_offset, result); } } @@ -297,7 +310,7 @@ void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, AutoGlobalLock<EightBitIntGemmLockId> lock; GemmContext* context = GetOrCreateGlobalContext(); -#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, ldb, ldc, bit_depth)) { 
MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset, @@ -334,7 +347,7 @@ void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, AutoGlobalLock<EightBitIntGemmLockId> lock; GemmContext* context = GetOrCreateGlobalContext(); -#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, ldb, ldc, bit_depth)) { MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset, diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.h b/eight_bit_int_gemm/eight_bit_int_gemm.h index 6bd9dfe..6bda427 100644 --- a/eight_bit_int_gemm/eight_bit_int_gemm.h +++ b/eight_bit_int_gemm/eight_bit_int_gemm.h @@ -24,8 +24,6 @@ namespace std { using ::uint8_t; using ::int32_t; -using ::int64_t; -using ::uint64_t; } #endif @@ -46,7 +44,7 @@ namespace eight_bit_int_gemm { // Users who prefer a state-less, singleton-less interface, // should use the main gemmlowp interface (public/gemmlowp.h) instead. -// The main entry point to compute a Gemm. This is the standard +// The BitDepthSetting enum lists supported a/b bit-depth combinations. enum class BitDepthSetting { A8B8, // 8-bit a, 8-bit b A5B7 // 5-bit a, 7-bit b |