author    Miao Wang <miaowang@google.com>  2017-07-06 11:06:31 -0700
committer Miao Wang <miaowang@google.com>  2017-07-06 20:20:07 +0000
commit    a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1 (patch)
tree      c6c4113609be686ad91ff6710922d8416ca3be1c /eight_bit_int_gemm
parent    b76ecb153a784e70e07bf7e19bcf3dfd6caec815 (diff)
download  gemmlowp-a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1.tar.gz
Rebase gemmlowp to 36ffd29
Test: mm
Test: build system image for sailfish
Test: BLAS CTS tests pass
Change-Id: I4cc9dbfd586f6653fc2d04e8e7ad78ada5d7dbe9
Diffstat (limited to 'eight_bit_int_gemm')
-rw-r--r--  eight_bit_int_gemm/eight_bit_int_gemm.cc  |  77
-rw-r--r--  eight_bit_int_gemm/eight_bit_int_gemm.h   |   4
2 files changed, 46 insertions, 35 deletions
diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.cc b/eight_bit_int_gemm/eight_bit_int_gemm.cc
index ecea180..8113bf3 100644
--- a/eight_bit_int_gemm/eight_bit_int_gemm.cc
+++ b/eight_bit_int_gemm/eight_bit_int_gemm.cc
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
+#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
+#endif
#include "eight_bit_int_gemm.h"
#include <memory>
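The hunk above force-defines GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK (without clobbering a definition the build system may already provide) so this translation unit always compiles, even on targets with no SIMD kernels. A hedged sketch of how such a guard is typically consumed downstream; the #error branch is an assumption about the check inside gemmlowp, not code from this commit:

// Illustrative only: how a kernel-selection header might consume the
// fallback macro defined above.
#if defined(GEMMLOWP_NEON) || defined(GEMMLOWP_SSE4)
// SIMD kernels are available; nothing to do.
#elif defined(GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK)
// Portable scalar kernels will be used: correct, but slow.
#else
#error "No SIMD path available and the slow scalar fallback was not allowed."
#endif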
@@ -30,8 +33,14 @@
// is quite significant (approx. 200kb) which might be prohibitive in
// low-memory situations.
-#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
-#include "../meta/multi_thread_gemm.h"
+#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
+#include "../meta/legacy_multi_thread_gemm.h"
+#else
+
+#if defined(GEMMLOWP_USE_META_FASTPATH)
+#warning "META fast path turned on without NEON!"
+#endif
+
#endif
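With this change the meta headers are pulled in for any NEON target, 32- or 64-bit, via the umbrella GEMMLOWP_NEON macro instead of GEMMLOWP_NEON_32 only, and a diagnostic fires if the fastpath is requested on a non-NEON build. A self-contained sketch of the same guard shape; the GEMMLOWP_* macros are gemmlowp's, everything else is illustrative:

// Prefer the fast path when both the feature flag and the SIMD target are
// present; warn when the flag is set but the target is not.
#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
constexpr bool kMetaFastpathCompiled = true;
#else
#if defined(GEMMLOWP_USE_META_FASTPATH)
#warning "META fast path turned on without NEON!"
#endif
constexpr bool kMetaFastpathCompiled = false;
#endif

int main() { return kMetaFastpathCompiled ? 0 : 1; }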
namespace gemmlowp {
@@ -136,25 +145,31 @@ void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
class Scratch {
public:
- Scratch() : buffer_(), size_(0) {}
+ Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {}
void AssureSize(std::int32_t required_size) {
if (size_ >= required_size) {
return;
}
- buffer_.reset(new std::uint8_t[required_size]);
+ buffer_.reset(new std::uint8_t[required_size + 32]);
+ buffer_32_ =
+ buffer_.get() +
+ ((32 - (reinterpret_cast<uintptr_t>(buffer_.get()) % 32)) % 32);
+ assert((reinterpret_cast<uintptr_t>(buffer_32_) % 32) == 0);
size_ = required_size;
}
void Clear() {
buffer_.reset(nullptr);
+ buffer_32_ = nullptr;
size_ = 0;
}
- std::uint8_t* buffer() { return buffer_.get(); }
+ std::uint8_t* buffer() { return buffer_32_; }
private:
std::unique_ptr<std::uint8_t[]> buffer_;
+ std::uint8_t* buffer_32_;
std::int32_t size_;
};
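The reworked Scratch over-allocates by 32 bytes and rounds the base pointer up to the next 32-byte boundary, so callers of buffer() get storage suitable for aligned SIMD loads. A minimal standalone sketch of the arithmetic, assuming only standard C++ (the AlignUp32 helper name is illustrative, not part of this commit):

#include <cassert>
#include <cstdint>
#include <memory>

// Round p up to the next 32-byte boundary, mirroring the expression in
// Scratch::AssureSize. (32 - p % 32) % 32 is at most 31, which is why
// over-allocating by 32 bytes always leaves required_size usable bytes.
std::uint8_t* AlignUp32(std::uint8_t* p) {
  const std::uintptr_t misalignment = reinterpret_cast<std::uintptr_t>(p) % 32;
  return p + ((32 - misalignment) % 32);
}

int main() {
  const std::int32_t required_size = 1000;
  std::unique_ptr<std::uint8_t[]> buffer(new std::uint8_t[required_size + 32]);
  std::uint8_t* aligned = AlignUp32(buffer.get());
  assert(reinterpret_cast<std::uintptr_t>(aligned) % 32 == 0);
  // aligned now points to at least required_size usable bytes.
  return 0;
}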
@@ -172,7 +187,7 @@ void DestroyGlobalScratch() {
global_scratch = nullptr;
}
-#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
+#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
// Is it row major and nicely packed?
@@ -205,8 +220,8 @@ bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
int m, int n, int k, int lda, int ldb, int ldc,
BitDepthSetting depth_setting) {
- // Meta fastpath only supports 8bit x 8bit and k up to 2048.
- if (depth_setting != BitDepthSetting::A8B8 || k > 2048) {
+ // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048.
+ if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) {
return false;
}
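The guard now also rejects k < 8; presumably the meta kernels require a minimum depth, though the commit does not state the reason. The depth-related part of the eligibility check, restated in isolation with the BitDepthSetting enum as declared in eight_bit_int_gemm.h:

// Depth-related slice of the fastpath eligibility check added above.
bool DepthSupportedByMetaFastpath(BitDepthSetting depth_setting, int k) {
  return depth_setting == BitDepthSetting::A8B8 && k >= 8 && k <= 2048;
}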
@@ -242,20 +257,19 @@ void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
std::int32_t shift, bool result_transpose,
std::int32_t result_stride, std::uint8_t* result) {
Scratch* scratch = GetOrCreateGlobalScratch();
+ const std::int32_t max_num_threads = context->max_num_threads();
if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
- scratch->AssureSize(
- meta::gemm_q8_scratch(m, n, k, context->max_num_threads()));
- meta::multi_thread_gemm_q8(
- context->workers_pool(), context->max_num_threads(), scratch->buffer(),
- lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset,
- multiplicative_offset, shift, result);
+ scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads));
+ meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
+ scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
+ rhs_offset, sum_offset, multiplicative_offset,
+ shift, result);
} else {
- scratch->AssureSize(
- meta::gemm_q8_scratch(n, m, k, context->max_num_threads()));
- meta::multi_thread_gemm_q8(
- context->workers_pool(), context->max_num_threads(), scratch->buffer(),
- rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset,
- multiplicative_offset, shift, result);
+ scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads));
+ meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
+ scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
+ lhs_offset, sum_offset, multiplicative_offset,
+ shift, result);
}
}
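Both branches call the same row-major meta kernel. For a column-major result the code exploits the identity (A*B)^T = B^T * A^T: swapping lhs/rhs, m/n, and the corresponding offsets makes the kernel write the transposed product, which is exactly the column-major layout the caller asked for. MetaGemmFloat below uses the same trick. A small numeric sketch of the identity, independent of gemmlowp:

#include <cassert>

int main() {
  const int m = 2, n = 3, k = 2;
  const int A[m][k] = {{1, 2}, {3, 4}};         // row-major m x k
  const int B[k][n] = {{5, 6, 7}, {8, 9, 10}};  // row-major k x n
  int C_row[m][n] = {};  // C = A*B stored row-major
  int C_col[n][m] = {};  // C stored column-major, i.e. (A*B) transposed

  // Direct product, row-major result.
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p) C_row[i][j] += A[i][p] * B[p][j];

  // Swapped operands, as in the else-branch above: row-major B^T * A^T.
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      for (int p = 0; p < k; ++p) C_col[j][i] += B[p][j] * A[i][p];

  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) assert(C_row[i][j] == C_col[j][i]);
  return 0;
}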
@@ -267,18 +281,17 @@ void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
float result_offset, bool result_transpose,
std::int32_t result_stride, float* result) {
Scratch* scratch = GetOrCreateGlobalScratch();
+ const std::int32_t max_num_threads = context->max_num_threads();
if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
- scratch->AssureSize(
- meta::gemm_f_scratch(m, n, k, context->max_num_threads()));
- meta::multi_thread_gemm_f(
- context->workers_pool(), context->max_num_threads(), scratch->buffer(),
- lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result);
+ scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads));
+ meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
+ scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
+ rhs_offset, result_offset, result);
} else {
- scratch->AssureSize(
- meta::gemm_f_scratch(n, m, k, context->max_num_threads()));
- meta::multi_thread_gemm_f(
- context->workers_pool(), context->max_num_threads(), scratch->buffer(),
- rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result);
+ scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads));
+ meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
+ scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
+ lhs_offset, result_offset, result);
}
}
@@ -297,7 +310,7 @@ void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
AutoGlobalLock<EightBitIntGemmLockId> lock;
GemmContext* context = GetOrCreateGlobalContext();
-#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
+#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
ldb, ldc, bit_depth)) {
MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
@@ -334,7 +347,7 @@ void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
AutoGlobalLock<EightBitIntGemmLockId> lock;
GemmContext* context = GetOrCreateGlobalContext();
-#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
+#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
ldb, ldc, bit_depth)) {
MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
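Both EightBitIntGemm overloads now gate the fastpath on GEMMLOWP_NEON, so 64-bit ARM builds can take it as well. The dispatch shape, as a runnable skeleton; every name except the GEMMLOWP_* macros is a hypothetical stand-in for the real gemmlowp code:

#include <cstdio>

namespace {
bool CanHandleFastpath(int k) { return k >= 8 && k <= 2048; }
void MetaFastpath() { std::puts("meta fastpath"); }
void GenericPath() { std::puts("generic gemmlowp path"); }
}  // namespace

// Try the meta fastpath when it is compiled in and the problem qualifies;
// otherwise fall through to the generic bit-depth dispatch.
void Dispatch(int k) {
#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
  if (CanHandleFastpath(k)) {
    MetaFastpath();
    return;
  }
#endif
  GenericPath();
}

int main() {
  Dispatch(512);
  return 0;
}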
diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.h b/eight_bit_int_gemm/eight_bit_int_gemm.h
index 6bd9dfe..6bda427 100644
--- a/eight_bit_int_gemm/eight_bit_int_gemm.h
+++ b/eight_bit_int_gemm/eight_bit_int_gemm.h
@@ -24,8 +24,6 @@
namespace std {
using ::uint8_t;
using ::int32_t;
-using ::int64_t;
-using ::uint64_t;
}
#endif
@@ -46,7 +44,7 @@ namespace eight_bit_int_gemm {
// Users who prefer a state-less, singleton-less interface,
// should use the main gemmlowp interface (public/gemmlowp.h) instead.
-// The main entry point to compute a Gemm. This is the standard
+// The BitDepthSetting enum lists supported a/b bit-depth combinations.
enum class BitDepthSetting {
A8B8, // 8-bit a, 8-bit b
A5B7 // 5-bit a, 7-bit b
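Only A8B8 qualifies for the meta fastpath (see CanHandleMetaFastpath in the .cc diff above); A5B7 always takes the generic path. A self-contained sketch using a local mirror of the enum, since only this fragment of the header appears in the diff; ChooseDepthSetting is a hypothetical helper, not part of gemmlowp:

// Local mirror of the enum from eight_bit_int_gemm.h, for illustration.
enum class BitDepthSetting {
  A8B8,  // 8-bit a, 8-bit b
  A5B7   // 5-bit a, 7-bit b
};

// Pick the depth setting; only A8B8 can hit the meta fastpath.
BitDepthSetting ChooseDepthSetting(bool full_eight_bit) {
  return full_eight_bit ? BitDepthSetting::A8B8 : BitDepthSetting::A5B7;
}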