diff options
Diffstat (limited to 'eight_bit_int_gemm/eight_bit_int_gemm.cc')
-rw-r--r-- | eight_bit_int_gemm/eight_bit_int_gemm.cc | 77 |
1 files changed, 45 insertions, 32 deletions
diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.cc b/eight_bit_int_gemm/eight_bit_int_gemm.cc index ecea180..8113bf3 100644 --- a/eight_bit_int_gemm/eight_bit_int_gemm.cc +++ b/eight_bit_int_gemm/eight_bit_int_gemm.cc @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#endif #include "eight_bit_int_gemm.h" #include <memory> @@ -30,8 +33,14 @@ // is quite significant (approx. 200kb) which might be prohibitive in // low-memory situations. -#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) -#include "../meta/multi_thread_gemm.h" +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) +#include "../meta/legacy_multi_thread_gemm.h" +#else + +#if defined(GEMMLOWP_USE_META_FASTPATH) +#warning "META fast path turned on without NEON!" +#endif + #endif namespace gemmlowp { @@ -136,25 +145,31 @@ void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k, class Scratch { public: - Scratch() : buffer_(), size_(0) {} + Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {} void AssureSize(std::int32_t required_size) { if (size_ >= required_size) { return; } - buffer_.reset(new std::uint8_t[required_size]); + buffer_.reset(new std::uint8_t[required_size + 32]); + buffer_32_ = + buffer_.get() + + ((32 - (reinterpret_cast<uintptr_t>(buffer_.get()) % 32)) % 32); + assert((reinterpret_cast<uintptr_t>(buffer_32_) % 32) == 0); size_ = required_size; } void Clear() { buffer_.reset(nullptr); + buffer_32_ = nullptr; size_ = 0; } - std::uint8_t* buffer() { return buffer_.get(); } + std::uint8_t* buffer() { return buffer_32_; } private: std::unique_ptr<std::uint8_t[]> buffer_; + std::uint8_t* buffer_32_; std::int32_t size_; }; @@ -172,7 +187,7 @@ void DestroyGlobalScratch() { global_scratch = nullptr; } -#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) { // Is it row major and nicely packed? @@ -205,8 +220,8 @@ bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) { bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c, int m, int n, int k, int lda, int ldb, int ldc, BitDepthSetting depth_setting) { - // Meta fastpath only supports 8bit x 8bit and k up to 2048. - if (depth_setting != BitDepthSetting::A8B8 || k > 2048) { + // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048. + if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) { return false; } @@ -242,20 +257,19 @@ void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs, std::int32_t shift, bool result_transpose, std::int32_t result_stride, std::uint8_t* result) { Scratch* scratch = GetOrCreateGlobalScratch(); + const std::int32_t max_num_threads = context->max_num_threads(); if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { - scratch->AssureSize( - meta::gemm_q8_scratch(m, n, k, context->max_num_threads())); - meta::multi_thread_gemm_q8( - context->workers_pool(), context->max_num_threads(), scratch->buffer(), - lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset, - multiplicative_offset, shift, result); + scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads)); + meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads, + scratch->buffer(), lhs, rhs, m, n, k, lhs_offset, + rhs_offset, sum_offset, multiplicative_offset, + shift, result); } else { - scratch->AssureSize( - meta::gemm_q8_scratch(n, m, k, context->max_num_threads())); - meta::multi_thread_gemm_q8( - context->workers_pool(), context->max_num_threads(), scratch->buffer(), - rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset, - multiplicative_offset, shift, result); + scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads)); + meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads, + scratch->buffer(), rhs, lhs, n, m, k, rhs_offset, + lhs_offset, sum_offset, multiplicative_offset, + shift, result); } } @@ -267,18 +281,17 @@ void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs, float result_offset, bool result_transpose, std::int32_t result_stride, float* result) { Scratch* scratch = GetOrCreateGlobalScratch(); + const std::int32_t max_num_threads = context->max_num_threads(); if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { - scratch->AssureSize( - meta::gemm_f_scratch(m, n, k, context->max_num_threads())); - meta::multi_thread_gemm_f( - context->workers_pool(), context->max_num_threads(), scratch->buffer(), - lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result); + scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads)); + meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads, + scratch->buffer(), lhs, rhs, m, n, k, lhs_offset, + rhs_offset, result_offset, result); } else { - scratch->AssureSize( - meta::gemm_f_scratch(n, m, k, context->max_num_threads())); - meta::multi_thread_gemm_f( - context->workers_pool(), context->max_num_threads(), scratch->buffer(), - rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result); + scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads)); + meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads, + scratch->buffer(), rhs, lhs, n, m, k, rhs_offset, + lhs_offset, result_offset, result); } } @@ -297,7 +310,7 @@ void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, AutoGlobalLock<EightBitIntGemmLockId> lock; GemmContext* context = GetOrCreateGlobalContext(); -#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, ldb, ldc, bit_depth)) { MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset, @@ -334,7 +347,7 @@ void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, AutoGlobalLock<EightBitIntGemmLockId> lock; GemmContext* context = GetOrCreateGlobalContext(); -#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32) +#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, ldb, ldc, bit_depth)) { MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset, |