author     Miao Wang <miaowang@google.com>                         2019-08-27 17:50:04 -0700
committer  android-build-merger <android-build-merger@google.com>  2019-08-27 17:50:04 -0700
commit     68dcc597e650eeda114b77f22e6391e85f4c5437 (patch)
tree       f910ae75e271bc79a22d9a73b94da6bdbe01d330
parent     032b4e313c03a94aac1769c0a81ae7d49943fe4b (diff)
parent     e8a1111f830a39e429ecbab08972c370fe9dcfb0 (diff)
download   gemmlowp-68dcc597e650eeda114b77f22e6391e85f4c5437.tar.gz
Rebase gemmlowp to a227af1fdb47f250b5df07d6936366b0f8113b65 am: 70ba50cbca am: 36f90a2b7a am: 846c903a24 am: e8a1111f83

Change-Id: Ibd4d3c2ce93aec1001c278e313f98770d6b6676d
-rw-r--r--  doc/kernel.md | 14
-rw-r--r--  doc/public.md | 4
-rw-r--r--  doc/quantization.md | 2
-rw-r--r--  eight_bit_int_gemm/eight_bit_int_gemm.cc | 3
-rw-r--r--  fixedpoint/fixedpoint.h | 68
-rw-r--r--  fixedpoint/fixedpoint_avx.h | 218
-rw-r--r--  fixedpoint/fixedpoint_msa.h | 75
-rw-r--r--  fixedpoint/fixedpoint_neon.h | 30
-rw-r--r--  fixedpoint/fixedpoint_sse.h | 4
-rw-r--r--  internal/common.h | 142
-rw-r--r--  internal/detect_platform.h | 166
-rw-r--r--  internal/dispatch_gemm_shape.h | 16
-rw-r--r--  internal/kernel.h | 25
-rw-r--r--  internal/kernel_avx.h | 361
-rw-r--r--  internal/kernel_default.h | 70
-rw-r--r--  internal/kernel_msa.h | 488
-rw-r--r--  internal/kernel_neon.h | 295
-rw-r--r--  internal/kernel_sse.h | 2
-rw-r--r--  internal/multi_thread_gemm.h | 298
-rw-r--r--  internal/output.h | 92
-rw-r--r--  internal/output_avx.h | 19
-rw-r--r--  internal/output_msa.h | 613
-rw-r--r--  internal/output_neon.h | 233
-rw-r--r--  internal/pack.h | 22
-rw-r--r--  internal/pack_avx.h | 282
-rw-r--r--  internal/pack_msa.h | 78
-rw-r--r--  internal/pack_neon.h | 64
-rw-r--r--  internal/platform.h | 5
-rw-r--r--  internal/simd_wrappers.h | 155
-rw-r--r--  internal/simd_wrappers_common_neon_sse.h | 204
-rw-r--r--  internal/simd_wrappers_msa.h | 11
-rw-r--r--  internal/simd_wrappers_neon.h | 370
-rw-r--r--  internal/unpack.h | 6
-rw-r--r--  meta/multi_thread_common.h | 6
-rw-r--r--  profiling/instrumentation.h | 11
-rw-r--r--  profiling/pthread_everywhere.h | 3
-rw-r--r--  public/bit_depth.h | 10
-rw-r--r--  public/map.h | 1
-rw-r--r--  public/output_stages.h | 32
39 files changed, 3938 insertions, 560 deletions
diff --git a/doc/kernel.md b/doc/kernel.md
index 261cb92..f3f2138 100644
--- a/doc/kernel.md
+++ b/doc/kernel.md
@@ -40,11 +40,15 @@ NEONKernel12x4Depth2 kernel, which specifies its format as
The meaning of these terms is explained in the lengthy comment at the top of
internal/kernel.h. Here, they mean that this kernel handles at each iteration
-(along the depth dimension): - 3 'cells' of size 4x2 each of the lhs, so a total
-lhs block of size 12x2 - 1 'cell' of size 2x4 of the rhs. In other words, this
-kernel handles 12 rows of the lhs and 4 columns of the rhs, and handles two
-levels of depth at once. The 'cells' and `CellFormat` detail the layout of these
-12x2 and 2x4 blocks.
+(along the depth dimension):
+
+- 3 'cells' of size 4x2 each of the lhs, so a total lhs block of size 12x2
+
+- 1 'cell' of size 2x4 of the rhs.
+
+In other words, this kernel handles 12 rows of the lhs and 4 columns of the
+rhs, and handles two levels of depth at once. The 'cells' and `CellFormat`
+detail the layout of these 12x2 and 2x4 blocks.
This kernel then loads these 12x2 and 2x4 blocks and computes the corresponding
12x4 GEMM; for ease of reference let us paste the critical comment and code
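As a concrete illustration of the format terms described above, here is a minimal sketch of how such a 12x4, depth-2 format is spelled with the CellFormat/KernelSideFormat/KernelFormat templates from internal/kernel.h. This is an illustrative sketch assuming default cell orders, not the exact typedef used by NEONKernel12x4Depth2:

// Illustrative sketch only (assumed cell orders), not the NEONKernel12x4Depth2 typedef.
// LHS side: 3 cells of width 4, depth 2  -> one 12x2 LHS block per iteration.
// RHS side: 1 cell of width 4, depth 2   -> one 2x4 RHS block per iteration.
typedef gemmlowp::KernelFormat<
    gemmlowp::KernelSideFormat<gemmlowp::CellFormat<4, 2>, 3>,   // LHS
    gemmlowp::KernelSideFormat<gemmlowp::CellFormat<4, 2>, 1> >  // RHS
    Sketch12x4Depth2Format;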
diff --git a/doc/public.md b/doc/public.md
index 935f6db..7739b85 100644
--- a/doc/public.md
+++ b/doc/public.md
@@ -14,7 +14,7 @@ The high-level overview of how this specifies a low-precision matrix
multiplication is explained in [low-precision.md](low-precision.md). The
rationale for a specific quantization paradigm is given in
[quantization.md](quantization.md). That specific quantization paradigm is
-implemented at two different stages of the computation: as pre-processing ont
+implemented at two different stages of the computation: as pre-processing on
the operands and as post-processing on the result:
* Pre-processing on the LHS, RHS operands, in the form of adding constant
@@ -56,7 +56,7 @@ being automatically deduced from function parameters:
* `InputScalar`: The scalar type of the LHS and RHS operands. At the moment,
this must be `std::uint8_t`.
-* `OutputScalar`: The scalar type of the LHS and RHS operands. At the moment,
+* `OutputScalar`: The scalar type of the result. At the moment,
this must be `std::uint8_t`.
* `BitDepthParams`: Defines the bit format of the input and output matrices
and the required accuracy of the computation. At the moment, the only
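To make the pre-processing on the operands concrete: each int32 accumulator entry is, conceptually, a dot product taken after adding the constant lhs_offset and rhs_offset to every operand entry. A minimal scalar sketch of that definition follows; the function and variable names are illustrative only, not part of gemmlowp's API, and gemmlowp computes this internally without literally materializing offset-added operands:

#include <cstdint>

// Sketch of the quantization paradigm's accumulator definition, in scalar form.
std::int32_t AccumulatorEntry(const std::uint8_t* lhs_row, const std::uint8_t* rhs_col,
                              int depth, std::int32_t lhs_offset, std::int32_t rhs_offset) {
  std::int32_t acc = 0;
  for (int d = 0; d < depth; ++d) {
    acc += (static_cast<std::int32_t>(lhs_row[d]) + lhs_offset) *
           (static_cast<std::int32_t>(rhs_col[d]) + rhs_offset);
  }
  return acc;  // the post-processing (output pipeline) then maps this int32 to uint8
}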
diff --git a/doc/quantization.md b/doc/quantization.md
index 3a8f72b..e5055e7 100644
--- a/doc/quantization.md
+++ b/doc/quantization.md
@@ -13,7 +13,7 @@ quantization paradigm affects the calculations that gemmlowp itself needs to
perform, specifically, it affects how one goes from internal 32bit accumulator
to final 8bit outputs.
-The part of gemmlowp transforming internal internal 32bit accumulator to final
+The part of gemmlowp transforming internal 32bit accumulator to final
8bit outputs is the "output pipeline" described in [output.md](output.md).
gemmlowp's `GemmWithOutputPipeline` entry point allows specifying an arbitrary
diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.cc b/eight_bit_int_gemm/eight_bit_int_gemm.cc
index 512c483..a8d9b43 100644
--- a/eight_bit_int_gemm/eight_bit_int_gemm.cc
+++ b/eight_bit_int_gemm/eight_bit_int_gemm.cc
@@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#endif
#include "eight_bit_int_gemm.h"
#include <memory>
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index d39341b..58e8050 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -18,10 +18,13 @@
#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
+#include <algorithm>
#include <cassert>
+#include <cmath>
+#include <cstdint>
#include <limits>
-#include "../internal/common.h"
+#include "../internal/detect_platform.h"
namespace gemmlowp {
@@ -47,13 +50,13 @@ struct FixedPointRawTypeTraits {};
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
typedef std::int32_t ScalarRawType;
- static const int kLanes = 1;
+ static constexpr int kLanes = 1;
};
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
typedef std::int16_t ScalarRawType;
- static const int kLanes = 1;
+ static constexpr int kLanes = 1;
};
// Returns a SIMD value duplicating a scalar value across all lanes.
@@ -109,11 +112,25 @@ tIntegerType Neg(tIntegerType a) {
return -a;
}
-// Integer arithmetic left-shift, equivalent to multiplying with a
-// power of two. Not saturating. Overflow is undefined behavior.
-template <typename tIntegerType>
-tIntegerType ShiftLeft(tIntegerType a, int offset) {
- return a << offset;
+// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
+// Negative values are OK. In case of overflow, no Undefined
+// Behavior, but the results are implementation-defined (in practice,
+// they currently are saturated, but we make no commitment to that). The idea
+// is that the caller will want to implement the overflowing cases with
+// saturation with compare-and-mask, so we don't care about the results
+// in the overflow case, we just want to avoid undefined behavior.
+//
+// tIntegerType may be int32 or any narrower signed type.
+template <typename tIntegerType, typename OffsetType>
+tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
+ const std::int64_t wide_a = static_cast<std::int64_t>(a);
+ const std::int64_t wide_shifted = wide_a * (1 << offset);
+ const auto min = std::numeric_limits<tIntegerType>::min();
+ const auto max = std::numeric_limits<tIntegerType>::max();
+ return wide_shifted < min
+ ? min
+ : wide_shifted > max ? max
+ : static_cast<tIntegerType>(wide_shifted);
}
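A small sketch of how this ShiftLeft behaves for plain int32 operands; the saturation shown is what the implementation above currently does, while the comment only promises the absence of undefined behavior (the sketch assumes the function above is in scope):

#include <cassert>
#include <cstdint>
#include <limits>

void ShiftLeftExamples() {
  assert(gemmlowp::ShiftLeft(std::int32_t{-3}, 4) == -48);   // negative values are fine
  assert(gemmlowp::ShiftLeft(std::int32_t{1} << 30, 2) ==
         std::numeric_limits<std::int32_t>::max());          // overflow: currently saturates
}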
// Integer arithmetic right-shift. Not rounding.
@@ -137,7 +154,7 @@ tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
- static const tIntegerType zero = 0;
+ static constexpr tIntegerType zero = 0;
return a ? BitNot(zero) : zero;
}
@@ -211,6 +228,7 @@ bool Any(tIntegerType a) {
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ (void)b;
return a;
}
@@ -235,6 +253,7 @@ inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ (void)b;
return a;
}
@@ -244,7 +263,9 @@ inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
std::int32_t a32 = a;
std::int32_t b32 = b;
std::int32_t sum = a32 + b32;
- return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum)));
+ return static_cast<std::int16_t>(
+ std::min(static_cast<std::int32_t>(32767),
+ std::max(static_cast<std::int32_t>(-32768), sum)));
}
// Returns a+b, saturating if the integers are 16bit or narrower,
@@ -298,6 +319,7 @@ IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ (void)b;
return a;
}
@@ -331,8 +353,8 @@ inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
-template <typename IntegerType>
-inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
+template <typename IntegerType, typename ExponentType>
+inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
assert(exponent >= 0);
assert(exponent <= 31);
const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
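A few values worked out by hand from the generic scalar implementation (only its first line is visible in this hunk) pin down what "correctly-rounded-to-nearest" means here: midpoints round away from zero. Sketch only:

#include <cassert>
#include <cstdint>

void RoundingDivideByPOTExamples() {
  assert(gemmlowp::RoundingDivideByPOT(std::int32_t{10}, 2) == 3);    //  2.5  ->  3
  assert(gemmlowp::RoundingDivideByPOT(std::int32_t{-9}, 2) == -2);   // -2.25 -> -2
  assert(gemmlowp::RoundingDivideByPOT(std::int32_t{-10}, 2) == -3);  // -2.5  -> -3
}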
@@ -432,9 +454,9 @@ class FixedPoint {
typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
- static const int kTotalBits = 8 * sizeof(ScalarRawType);
- static const int kIntegerBits = tIntegerBits;
- static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
+ static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
+ static constexpr int kIntegerBits = tIntegerBits;
+ static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
"bad IntegerBits");
@@ -474,7 +496,7 @@ class FixedPoint {
template <int Exponent>
static FixedPoint ConstantPOT() {
- static const int kOffset = kFractionalBits + Exponent;
+ static constexpr int kOffset = kFractionalBits + Exponent;
static_assert(
kOffset < 31,
"Constant not exactly representable in this fixed-point format");
@@ -645,7 +667,7 @@ double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
FixedPoint<tRawType, tIntegerBitsDst> Rescale(
FixedPoint<tRawType, tIntegerBitsSrc> x) {
- static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
+ static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
FixedPoint<tRawType, tIntegerBitsDst> result;
result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
return result;
@@ -725,9 +747,9 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
FixedPoint<tRawType, tIntegerBits> a) {
typedef FixedPoint<tRawType, tIntegerBits> InputF;
typedef FixedPoint<tRawType, 0> ResultF;
- static const int kFractionalBits = InputF::kFractionalBits;
- static const int kIntegerBits = InputF::kIntegerBits;
- static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
+ static constexpr int kFractionalBits = InputF::kFractionalBits;
+ static constexpr int kIntegerBits = InputF::kIntegerBits;
+ const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
@@ -755,10 +777,10 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
#undef GEMMLOWP_EXP_BARREL_SHIFTER
+ static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
if (kIntegerBits > 5) {
- static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
const InputF clamp =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
}
@@ -867,6 +889,8 @@ FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
#ifdef GEMMLOWP_NEON
#include "./fixedpoint_neon.h"
+#elif defined(GEMMLOWP_AVX2)
+#include "./fixedpoint_avx.h"
#elif defined(GEMMLOWP_SSE4)
#include "./fixedpoint_sse.h"
#elif defined(GEMMLOWP_MSA)
diff --git a/fixedpoint/fixedpoint_avx.h b/fixedpoint/fixedpoint_avx.h
new file mode 100644
index 0000000..1816386
--- /dev/null
+++ b/fixedpoint/fixedpoint_avx.h
@@ -0,0 +1,218 @@
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// fixedpoint_avx.h: optimized avx specializations of the templates
+// in fixedpoint.h.
+
+#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
+#define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
+
+#include <smmintrin.h>
+#include "fixedpoint.h"
+#include "fixedpoint_sse.h"
+
+namespace gemmlowp {
+
+template <>
+struct FixedPointRawTypeTraits<__m256i> {
+ typedef std::int32_t ScalarRawType;
+ static const int kLanes = 4;
+};
+
+template <>
+inline __m256i BitAnd(__m256i a, __m256i b) {
+ return _mm256_and_si256(a, b);
+}
+
+template <>
+inline __m256i BitOr(__m256i a, __m256i b) {
+ return _mm256_or_si256(a, b);
+}
+
+template <>
+inline __m256i BitXor(__m256i a, __m256i b) {
+ return _mm256_xor_si256(a, b);
+}
+
+template <>
+inline __m256i BitNot(__m256i a) {
+ return _mm256_andnot_si256(a, _mm256_set1_epi32(-1));
+}
+
+template <>
+inline __m256i Add(__m256i a, __m256i b) {
+ return _mm256_add_epi32(a, b);
+}
+
+template <>
+inline __m256i Mul(__m256i a, __m256i b) {
+ return _mm256_mullo_epi32(a, b);
+}
+
+template <>
+inline __m256i Sub(__m256i a, __m256i b) {
+ return _mm256_sub_epi32(a, b);
+}
+
+template <>
+inline __m256i Neg(__m256i a) {
+ return _mm256_sign_epi32(a, _mm256_set1_epi32(-1));
+}
+
+template <>
+inline __m256i ShiftLeft(__m256i a, int offset) {
+ return _mm256_slli_epi32(a, offset);
+}
+
+template <>
+inline __m256i ShiftRight(__m256i a, int offset) {
+ return _mm256_srai_epi32(a, offset);
+}
+
+template <>
+inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val,
+ __m256i else_val) {
+ return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(else_val),
+ _mm256_castsi256_ps(then_val),
+ _mm256_castsi256_ps(if_mask)));
+}
+
+template <>
+inline __m256i MaskIfEqual(__m256i a, __m256i b) {
+ return _mm256_cmpeq_epi32(a, b);
+}
+
+template <>
+inline __m256i MaskIfNotEqual(__m256i a, __m256i b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
+inline __m256i MaskIfZero(__m256i a) {
+ return MaskIfEqual(a, _mm256_set1_epi32(0));
+}
+
+template <>
+inline __m256i MaskIfNonZero(__m256i a) {
+ return MaskIfNotEqual(a, _mm256_set1_epi32(0));
+}
+
+template <>
+inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) {
+ return _mm256_cmpgt_epi32(a, b);
+}
+
+template <>
+inline __m256i MaskIfLessThan(__m256i a, __m256i b) {
+ return _mm256_cmpgt_epi32(b, a);
+}
+
+template <>
+inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) {
+ return BitNot(MaskIfLessThan(a, b));
+}
+
+template <>
+inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) {
+ return BitNot(MaskIfGreaterThan(a, b));
+}
+
+/* Assumptions:
+ - All and Any are used on masks.
+ - masks are all_ones for true lanes, all_zeroes otherwise.
+Hence, All means all 256 bits set, and Any means any bit set.
+*/
+
+template <>
+inline bool All(__m256i a) {
+ return _mm256_testc_si256(a, a);
+}
+
+template <>
+inline bool Any(__m256i a) {
+ return BitNot(_mm256_testz_si256(a, a));
+}
+
+template <>
+inline __m256i RoundingHalfSum(__m256i a, __m256i b) {
+ /* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */
+ /* We divide the inputs before the add to avoid the overflow and costly test
+ */
+ /* of checking if an overflow occurred on signed add */
+ /* round_bit_mask = _mm_set1_epi32(1); */
+ /* a_over_2 = _mm_srai_epi32(a, 1); */
+ /* b_over_2 = _mm_srai_epi32(b, 1); */
+ /* sum = Add(a_over_2, b_over_2); */
+ /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a,b), round_bit_mask), sum); */
+ /* return Add(sum, round_bit); */
+
+ /* Other possibility detecting overflow and xor the sign if an overflow
+ * happened*/
+ __m256i one, sign_bit_mask, sum, rounded_half_sum, overflow, result;
+ one = _mm256_set1_epi32(1);
+ sign_bit_mask = _mm256_set1_epi32(0x80000000);
+ sum = Add(a, b);
+ rounded_half_sum = _mm256_srai_epi32(Add(sum, one), 1);
+ overflow =
+ BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
+ sign_bit_mask);
+ result = BitXor(rounded_half_sum, overflow);
+ return result;
+}
+
+template <>
+inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) {
+ __m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
+ __m256i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
+ __m256i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result;
+ __m256i nudge;
+
+ // saturation only happens if a == b == INT_MIN
+ min = _mm256_set1_epi32(std::numeric_limits<std::int32_t>::min());
+ saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min));
+
+ // a = a0 | a1 | a2 | a3
+ // b = b0 | b1 | b2 | b3
+ a0_a2 = a;
+ a1_a3 = _mm256_srli_si256(a, 4);
+ b0_b2 = b;
+ b1_b3 = _mm256_srli_si256(b, 4);
+
+ a0b0_a2b2 = _mm256_mul_epi32(a0_a2, b0_b2);
+ a1b1_a3b3 = _mm256_mul_epi32(a1_a3, b1_b3);
+
+ // do the rounding and take into account that it will be doubled
+ nudge = _mm256_set1_epi64x(1 << 30);
+ a0b0_a2b2_rounded = _mm256_add_epi64(a0b0_a2b2, nudge);
+ a1b1_a3b3_rounded = _mm256_add_epi64(a1b1_a3b3, nudge);
+
+ // do the doubling
+ a0b0_a2b2_rounded_2x = _mm256_slli_epi64(a0b0_a2b2_rounded, 1);
+ a1b1_a3b3_rounded_2x = _mm256_slli_epi64(a1b1_a3b3_rounded, 1);
+
+ // get the high part of the products
+ result = _mm256_blend_epi16(_mm256_srli_si256(a0b0_a2b2_rounded_2x, 4),
+ a1b1_a3b3_rounded_2x, 0xcc);
+
+ // saturate those which overflowed
+ return SelectUsingMask(saturation_mask, min, result);
+}
+
+template <>
+inline __m256i Dup<__m256i>(std::int32_t x) {
+ return _mm256_set1_epi32(x);
+}
+
+} // end namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
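For the two less obvious routines in this new file, RoundingHalfSum (the xor-the-sign overflow correction) and SaturatingRoundingDoublingHighMul (the 64-bit product path), here is a scalar rendering of what each lane computes. It mirrors the AVX code above, not gemmlowp's existing scalar specializations, which differ in details such as the sign of the rounding nudge; treat it as a sketch:

#include <cstdint>
#include <limits>

// RoundingHalfSum: (a + b + 1) >> 1 as if computed in wider precision. The
// vector code gets the same result without int64 by detecting overflow of the
// wrapped sum via the sign bits and xor-ing the sign back in.
std::int32_t RoundingHalfSumSketch(std::int32_t a, std::int32_t b) {
  return static_cast<std::int32_t>(
      (static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b) + 1) >> 1);
}

// SaturatingRoundingDoublingHighMul: the high 32 bits of 2*a*b with a rounding
// nudge of 2^30. The only overflowing case, a == b == INT32_MIN, is handled by
// the saturation mask, which the AVX code above resolves to INT32_MIN.
std::int32_t SaturatingRoundingDoublingHighMulSketch(std::int32_t a, std::int32_t b) {
  const std::int32_t int32_min = std::numeric_limits<std::int32_t>::min();
  if (a == int32_min && b == int32_min) return int32_min;
  const std::int64_t ab = static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  const std::int64_t nudged = ab + (std::int64_t{1} << 30);
  return static_cast<std::int32_t>(nudged >> 31);  // same bits as the high half of 2*nudged
}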
diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h
index c7a110c..b17f32a 100644
--- a/fixedpoint/fixedpoint_msa.h
+++ b/fixedpoint/fixedpoint_msa.h
@@ -25,13 +25,13 @@ namespace gemmlowp {
template <>
struct FixedPointRawTypeTraits<v4i32> {
typedef std::int32_t ScalarRawType;
- static const int kLanes = 4;
+ static constexpr int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<v8i16> {
typedef std::int16_t ScalarRawType;
- static const int kLanes = 8;
+ static constexpr int kLanes = 8;
};
template <>
@@ -326,11 +326,71 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> {
}
};
-// TODO: possibly implement:
-// template <> v4i32 RoundingDivideByPOT(v4i32, int)
-// template <> v8i16 RoundingDivideByPOT(v8i16, int)
-// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1>
-// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1>
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> {
+ static v4i32 eval(v4i32 x) {
+ static_assert(-31 <= Exponent && Exponent <= -1, "");
+ // Isolate the sign bits.
+ v4i32 sign = __builtin_msa_srli_w(x, 31);
+ // Decrement the negative elements by 1 (with saturation).
+ x = __builtin_msa_subs_s_w(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srari instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srari_w(x, -Exponent);
+ }
+};
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> {
+ static v8i16 eval(v8i16 x) {
+ static_assert(-15 <= Exponent && Exponent <= -1, "");
+ // Isolate the sign bits.
+ v8i16 sign = __builtin_msa_srli_h(x, 15);
+ // Decrement the negative elements by 1 (with saturation).
+ x = __builtin_msa_subs_s_h(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srari instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srari_h(x, -Exponent);
+ }
+};
+
+template <>
+inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) {
+ v4i32 e = __builtin_msa_fill_w(exponent);
+ // Isolate the sign bits.
+ v4i32 sign = __builtin_msa_srli_w(x, 31);
+ // Reset them to 0 if exponent is 0.
+ sign = __builtin_msa_min_s_w(sign, e);
+ // Decrement the negative elements by 1 (with saturation)
+ // if exponent is non-zero.
+ x = __builtin_msa_subs_s_w(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srar instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srar_w(x, e);
+}
+
+template <>
+inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) {
+ v8i16 e = __builtin_msa_fill_h(exponent);
+ // Isolate the sign bits.
+ v8i16 sign = __builtin_msa_srli_h(x, 15);
+ // Reset them to 0 if exponent is 0.
+ sign = __builtin_msa_min_s_h(sign, e);
+ // Decrement the negative elements by 1 (with saturation)
+ // if exponent is non-zero.
+ x = __builtin_msa_subs_s_h(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srar instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srar_h(x, e);
+}
template <>
inline v4i32 Dup<v4i32>(std::int32_t x) {
@@ -346,7 +406,6 @@ inline v8i16 Dup<v8i16>(std::int16_t x) {
template <>
inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
return __builtin_msa_adds_s_h(a, b);
- return a;
}
} // end namespace gemmlowp
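The trick behind the new MSA rounding shifts above, in scalar form: srari/srar round midpoints toward +infinity, and subtracting 1 from negative inputs beforehand turns that into round-half-away-from-zero, matching the generic and NEON paths. A sketch, ignoring the saturation that the MSA subtract provides:

#include <cassert>
#include <cstdint>

std::int32_t RoundingShiftRightSketch(std::int32_t x, int exponent) {
  if (exponent == 0) return x;    // the vector code clamps the sign fixup with min(sign, e)
  if (x < 0) x -= 1;              // the MSA code uses a saturating subtract here
  const std::int32_t half = std::int32_t{1} << (exponent - 1);
  return (x + half) >> exponent;  // add-half-then-shift == round midpoints toward +infinity
}

void RoundingShiftRightExamples() {
  assert(RoundingShiftRightSketch(10, 2) == 3);    //  2.5  ->  3
  assert(RoundingShiftRightSketch(-9, 2) == -2);   // -2.25 -> -2
  assert(RoundingShiftRightSketch(-10, 2) == -3);  // -2.5  -> -3 (away from zero)
}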
diff --git a/fixedpoint/fixedpoint_neon.h b/fixedpoint/fixedpoint_neon.h
index 92b349b..4dab6c9 100644
--- a/fixedpoint/fixedpoint_neon.h
+++ b/fixedpoint/fixedpoint_neon.h
@@ -25,13 +25,13 @@ namespace gemmlowp {
template <>
struct FixedPointRawTypeTraits<int32x4_t> {
typedef std::int32_t ScalarRawType;
- static const int kLanes = 4;
+ static constexpr int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<int16x8_t> {
typedef std::int16_t ScalarRawType;
- static const int kLanes = 8;
+ static constexpr int kLanes = 8;
};
template <>
@@ -115,6 +115,16 @@ inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
}
template <>
+inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) {
+ return vshlq_s32(a, offset);
+}
+
+template <>
+inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) {
+ return vshlq_s16(a, offset);
+}
+
+template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
return vshlq_s32(a, vdupq_n_s32(-offset));
}
@@ -282,6 +292,22 @@ inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
return vrshlq_s16(fixed_up_x, shift_vec);
}
+template <>
+inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) {
+ const int32x4_t shift_vec = vnegq_s32(exponent);
+ const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
+ const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
+ return vrshlq_s32(fixed_up_x, shift_vec);
+}
+
+template <>
+inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) {
+ const int16x8_t shift_vec = vnegq_s16(exponent);
+ const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
+ const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
+ return vrshlq_s16(fixed_up_x, shift_vec);
+}
+
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h
index ba990f0..a1fae32 100644
--- a/fixedpoint/fixedpoint_sse.h
+++ b/fixedpoint/fixedpoint_sse.h
@@ -42,13 +42,13 @@ struct int16x8_m128i {
template <>
struct FixedPointRawTypeTraits<__m128i> {
typedef std::int32_t ScalarRawType;
- static const int kLanes = 4;
+ static constexpr int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<int16x8_m128i> {
typedef std::int16_t ScalarRawType;
- static const int kLanes = 8;
+ static constexpr int kLanes = 8;
};
template <>
diff --git a/internal/common.h b/internal/common.h
index 26b6713..332ad07 100644
--- a/internal/common.h
+++ b/internal/common.h
@@ -26,144 +26,9 @@
#include <cmath>
#include <cstdlib>
+#include "../internal/detect_platform.h"
#include "../profiling/instrumentation.h"
-// Our inline assembly path assume GCC/Clang syntax.
-// Native Client doesn't seem to support inline assembly(?).
-#if defined(__GNUC__) && !defined(__native_client__)
-#define GEMMLOWP_ALLOW_INLINE_ASM
-#endif
-
-// Define macro statement that avoids inlining for GCC.
-// For non-GCC, define as empty macro.
-#if defined(__GNUC__)
-#define GEMMLOWP_NOINLINE __attribute__((noinline))
-#else
-#define GEMMLOWP_NOINLINE
-#endif
-
-// Detect ARM, 32-bit or 64-bit
-#ifdef __arm__
-#define GEMMLOWP_ARM_32
-#endif
-
-#ifdef __aarch64__
-#define GEMMLOWP_ARM_64
-#endif
-
-#if defined(GEMMLOWP_ARM_32) || defined(GEMMLOWP_ARM_64)
-#define GEMMLOWP_ARM
-#endif
-
-// Detect MIPS, 32-bit or 64-bit
-#if defined(__mips) && !defined(__LP64__)
-#define GEMMLOWP_MIPS_32
-#endif
-
-#if defined(__mips) && defined(__LP64__)
-#define GEMMLOWP_MIPS_64
-#endif
-
-#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64)
-#define GEMMLOWP_MIPS
-#endif
-
-// Detect x86, 32-bit or 64-bit
-#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386)
-#define GEMMLOWP_X86_32
-#endif
-
-#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64)
-#define GEMMLOWP_X86_64
-#endif
-
-#if defined(GEMMLOWP_X86_32) || defined(GEMMLOWP_X86_64)
-#define GEMMLOWP_X86
-#endif
-
-// Some of our optimized paths use inline assembly and for
-// now we don't bother enabling some other optimized paths using intrinddics
-// where we can't use inline assembly paths.
-#ifdef GEMMLOWP_ALLOW_INLINE_ASM
-
-// Detect NEON. It's important to check for both tokens.
-#if (defined __ARM_NEON) || (defined __ARM_NEON__)
-#define GEMMLOWP_NEON
-#endif
-
-// Convenience NEON tokens for 32-bit or 64-bit
-#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_32)
-#define GEMMLOWP_NEON_32
-#endif
-
-#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_64)
-#define GEMMLOWP_NEON_64
-#endif
-
-// Detect MIPS MSA.
-// Limit MSA optimizations to little-endian CPUs for now.
-// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs?
-#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \
- defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
-#define GEMMLOWP_MSA
-#endif
-
-// Convenience MIPS MSA tokens for 32-bit or 64-bit.
-#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32)
-#define GEMMLOWP_MSA_32
-#endif
-
-#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64)
-#define GEMMLOWP_MSA_64
-#endif
-
-// Detect SSE.
-#ifdef __SSE4_1__
-#define GEMMLOWP_SSE4
-#endif
-
-#ifdef __SSE3__
-#define GEMMLOWP_SSE3
-#endif
-
-// Convenience SSE4 tokens for 32-bit or 64-bit
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \
- !defined(GEMMLOWP_DISABLE_SSE4)
-#define GEMMLOWP_SSE4_32
-#endif
-
-#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32)
-#define GEMMLOWP_SSE3_32
-#endif
-
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \
- !defined(GEMMLOWP_DISABLE_SSE4)
-#define GEMMLOWP_SSE4_64
-#endif
-
-#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64)
-#define GEMMLOWP_SSE3_64
-#endif
-
-#if defined(__has_feature)
-#if __has_feature(memory_sanitizer)
-#include <sanitizer/msan_interface.h>
-#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __msan_unpoison
-#elif __has_feature(address_sanitizer)
-#include <sanitizer/asan_interface.h>
-#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __asan_unpoison_memory_region
-#endif
-#endif
-
-#endif // GEMMLOWP_ALLOW_INLINE_ASM
-
-// Detect Android. Don't conflate with ARM - we care about tuning
-// for non-ARM Android devices too. This can be used in conjunction
-// with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs.
-#if defined(__ANDROID__) || defined(ANDROID)
-#define GEMMLOWP_ANDROID
-#endif
-
namespace gemmlowp {
// Standard cache line size. Useful to optimize alignment and
@@ -242,7 +107,12 @@ const float kDefaultL2RhsFactor = 0.75f;
// size, so any size would work there. Different platforms may set this
// to different values but must ensure that their own optimized packing paths
// are consistent with this value.
+
+#ifdef GEMMLOWP_AVX2
+const int kRegisterSize = 32;
+#else
const int kRegisterSize = 16;
+#endif
// Hints the CPU to prefetch the cache line containing ptr.
inline void Prefetch(const void* ptr) {
diff --git a/internal/detect_platform.h b/internal/detect_platform.h
new file mode 100644
index 0000000..6f06d19
--- /dev/null
+++ b/internal/detect_platform.h
@@ -0,0 +1,166 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// detect_platform.h: Sets up macros that control architecture-specific
+// features of gemmlowp's implementation.
+
+#ifndef GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_
+#define GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_
+
+// Our inline assembly paths assume GCC/Clang syntax.
+// Native Client doesn't seem to support inline assembly(?).
+#if defined(__GNUC__) && !defined(__native_client__)
+#define GEMMLOWP_ALLOW_INLINE_ASM
+#endif
+
+// Define macro statement that avoids inlining for GCC.
+// For non-GCC, define as empty macro.
+#if defined(__GNUC__)
+#define GEMMLOWP_NOINLINE __attribute__((noinline))
+#else
+#define GEMMLOWP_NOINLINE
+#endif
+
+// Detect ARM, 32-bit or 64-bit
+#ifdef __arm__
+#define GEMMLOWP_ARM_32
+#endif
+
+#ifdef __aarch64__
+#define GEMMLOWP_ARM_64
+#endif
+
+#if defined(GEMMLOWP_ARM_32) || defined(GEMMLOWP_ARM_64)
+#define GEMMLOWP_ARM
+#endif
+
+// Detect MIPS, 32-bit or 64-bit
+#if defined(__mips) && !defined(__LP64__)
+#define GEMMLOWP_MIPS_32
+#endif
+
+#if defined(__mips) && defined(__LP64__)
+#define GEMMLOWP_MIPS_64
+#endif
+
+#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MIPS
+#endif
+
+// Detect x86, 32-bit or 64-bit
+#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386)
+#define GEMMLOWP_X86_32
+#endif
+
+#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64)
+#define GEMMLOWP_X86_64
+#endif
+
+#if defined(GEMMLOWP_X86_32) || defined(GEMMLOWP_X86_64)
+#define GEMMLOWP_X86
+#endif
+
+// Some of our optimized paths use inline assembly and for
+// now we don't bother enabling some other optimized paths using intrinsics
+// where we can't use inline assembly paths.
+#ifdef GEMMLOWP_ALLOW_INLINE_ASM
+
+// Detect NEON. It's important to check for both tokens.
+#if (defined __ARM_NEON) || (defined __ARM_NEON__)
+#define GEMMLOWP_NEON
+#endif
+
+// Convenience NEON tokens for 32-bit or 64-bit
+#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_32)
+#define GEMMLOWP_NEON_32
+#endif
+
+#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_64)
+#define GEMMLOWP_NEON_64
+#endif
+
+// Detect MIPS MSA.
+// Limit MSA optimizations to little-endian CPUs for now.
+// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs?
+#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \
+ defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+#define GEMMLOWP_MSA
+#endif
+
+// Convenience MIPS MSA tokens for 32-bit or 64-bit.
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32)
+#define GEMMLOWP_MSA_32
+#endif
+
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MSA_64
+#endif
+
+// Detect AVX2. In addition to the compiler's __AVX2__ define, this path must
+// be opted into with the compiler define -D GEMMLOWP_ENABLE_AVX2.
+#if defined(__AVX2__) && defined(GEMMLOWP_ENABLE_AVX2)
+#define GEMMLOWP_AVX2
+// Detect SSE4.
+// MSVC does not have __SSE4_1__ macro, but will enable SSE4
+// when AVX is turned on.
+#elif defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__))
+#define GEMMLOWP_SSE4
+// Detect SSE3.
+#elif defined(__SSE3__)
+#define GEMMLOWP_SSE3
+#endif
+
+// Convenience SSE4 tokens for 32-bit or 64-bit
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \
+ !defined(GEMMLOWP_DISABLE_SSE4)
+#define GEMMLOWP_SSE4_32
+#endif
+
+#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32)
+#define GEMMLOWP_SSE3_32
+#endif
+
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \
+ !defined(GEMMLOWP_DISABLE_SSE4)
+#define GEMMLOWP_SSE4_64
+#endif
+
+#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64)
+#define GEMMLOWP_SSE3_64
+#endif
+
+#if defined(GEMMLOWP_AVX2) && defined(GEMMLOWP_X86_64)
+#define GEMMLOWP_AVX2_64
+#endif
+
+#if defined(__has_feature)
+#if __has_feature(memory_sanitizer)
+#include <sanitizer/msan_interface.h>
+#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __msan_unpoison
+#elif __has_feature(address_sanitizer)
+#include <sanitizer/asan_interface.h>
+#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __asan_unpoison_memory_region
+#endif
+#endif
+
+#endif // GEMMLOWP_ALLOW_INLINE_ASM
+
+// Detect Android. Don't conflate with ARM - we care about tuning
+// for non-ARM Android devices too. This can be used in conjunction
+// with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs.
+#if defined(__ANDROID__) || defined(ANDROID)
+#define GEMMLOWP_ANDROID
+#endif
+
+#endif // GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_
diff --git a/internal/dispatch_gemm_shape.h b/internal/dispatch_gemm_shape.h
index 0be0bf3..ba4f341 100644
--- a/internal/dispatch_gemm_shape.h
+++ b/internal/dispatch_gemm_shape.h
@@ -85,6 +85,22 @@ struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
}
};
+template <VectorShape Shape>
+struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
+ typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType;
+ static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
+ typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>
+ DstType;
+ static DstType Run(const SrcType& src) {
+ DstType dst;
+ dst.result_fixedpoint_multiplier =
+ Transpose(src.result_fixedpoint_multiplier);
+ dst.result_exponent = Transpose(src.result_exponent);
+ dst.result_offset_after_shift = src.result_offset_after_shift;
+ return dst;
+ }
+};
+
template <typename VectorMapType>
struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
typedef OutputStageBiasAddition<VectorMapType> SrcType;
diff --git a/internal/kernel.h b/internal/kernel.h
index 825a7f3..3120216 100644
--- a/internal/kernel.h
+++ b/internal/kernel.h
@@ -145,12 +145,24 @@ struct KernelSideFormat {
static const int kCells = tCells;
static const int kWidth = kCells * Cell::kWidth;
static const int kDepth = Cell::kDepth;
- typedef std::uint8_t Scalar;
+ typedef std::uint8_t Scalar; // The scalar type of the Format.
+ typedef std::uint8_t InputScalar; // The scalar type of the original input.
};
+// KernelSideFormat for int8 fast kernel trick. The original input is uint8, but
+// the packing stage converts it to int8.
template <typename tCellFormat, int tCells>
struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
typedef std::int8_t Scalar;
+ typedef std::uint8_t InputScalar;
+};
+
+// KernelSideFormat for int8 inputs, enabling int8 fast kernel trick without
+// pack conversion.
+template <typename tCellFormat, int tCells>
+struct KernelSideFormatInt8Inputs : KernelSideFormat<tCellFormat, tCells> {
+ typedef std::int8_t Scalar;
+ typedef std::int8_t InputScalar;
};
// KernelFormat describes fully the input data layout that a kernel expects.
@@ -216,19 +228,24 @@ struct KernelBase {
virtual ~KernelBase() {}
};
-template <typename KernelScalarType>
+template <typename InputKernelScalarType, typename KernelScalarType>
struct ZeroPointInputValue {};
template <>
-struct ZeroPointInputValue<std::uint8_t> {
+struct ZeroPointInputValue<std::uint8_t, std::uint8_t> {
static constexpr std::uint8_t kValue = 0;
};
template <>
-struct ZeroPointInputValue<std::int8_t> {
+struct ZeroPointInputValue<std::uint8_t, std::int8_t> {
static constexpr std::uint8_t kValue = 128;
};
+template <>
+struct ZeroPointInputValue<std::int8_t, std::int8_t> {
+ static constexpr std::uint8_t kValue = 0;
+};
+
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_KERNEL_H_
diff --git a/internal/kernel_avx.h b/internal/kernel_avx.h
new file mode 100644
index 0000000..2fe1249
--- /dev/null
+++ b/internal/kernel_avx.h
@@ -0,0 +1,361 @@
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// kernel_avx.h: a collection of Intel AVX2 optimized kernels.
+// Check in kernel_default.h which one(s) are actually used by default.
+// Others are mere experiments; they are still covered by tests
+// in case they might be useful some day.
+//
+
+#ifndef GEMMLOWP_INTERNAL_KERNEL_AVX_H_
+#define GEMMLOWP_INTERNAL_KERNEL_AVX_H_
+
+#include "kernel.h"
+
+#include <string.h>
+#include <cassert>
+
+namespace gemmlowp {
+
+#ifdef GEMMLOWP_AVX2_64
+struct AVX2_64_Kernel24x8Depth2 : KernelBase {
+ typedef KernelFormat<KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1>>
+ Format;
+
+ const char *Name() const override { return "AVX, 24x8, depth 2"; }
+
+ void Run(std::int32_t *dst_ptr, std::size_t dst_row_stride, std::size_t dst_col_stride,
+ const std::uint8_t *lhs_ptr, const std::uint8_t *rhs_ptr, std::size_t start_depth,
+ std::size_t run_depth) const override {
+ ScopedProfilingLabel label("optimized kernel");
+ assert(dst_row_stride == 1);
+ const std::int64_t run_depth_cells = run_depth / Format::kDepth;
+ const std::int64_t dst_col_stride_q = dst_col_stride;
+
+ /* Main loop */
+
+ // A 2x8 cell of Rhs is stored in 16bit in ymm1 .
+ // A 24x2 block of 3 8x2 cells Lhs is stored in 16bit in ymm0, replaced
+ // every Iteration.
+ // A 24x4 block of accumulators is stored in 32bit in ymm4--ymm15.
+ //
+ // +-------+-------+-------+-------+
+ // |ymm1[0] |ymm2[2] |
+ // Rhs +-------+---------------+-------+
+ // |ymm1[1] |ymm1[4] |
+ // +-------+-------+-------+-------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +--+--+ - - - - +-------+-------+-------+-------+
+ // |ymm0 | | ymm4 | ymm5 | ymm6 | ymm7 |
+ // |ymm0 | (Iter1) | ymm4 | ymm5 | ymm6 | ymm7 |
+ // |ymm0 | | ymm4 | ymm5 | ymm6 | ymm7 |
+ // |ymm0 | | ymm4 | ymm5 | ymm6 | ymm7 |
+ // +--+--+ - - - - +-------+-------+-------+-------+
+ // |ymm0 | | ymm8 | ymm9 | ymm10 | ymm11 |
+ // |ymm0 | (Iter2) | ymm8 | ymm9 | ymm10 | ymm11 |
+ // |ymm0 | | ymm8 | ymm9 | ymm10 | ymm11 |
+ // |ymm0 | | ymm8 | ymm9 | ymm10 | ymm11 |
+ // +--+--+ - - - - +-------+-------+-------+-------+
+ // |ymm0 | | ymm12 | ymm13 | ymm14 | ymm15 |
+ // |ymm0 | (Iter3) | ymm12 | ymm13 | ymm14 | ymm15 |
+ // |ymm0 | | ymm12 | ymm13 | ymm14 | ymm15 |
+ // |ymm0 | | ymm12 | ymm13 | ymm14 | ymm15 |
+ // +--+--+ - - - - +-------+-------+-------+-------+
+ //
+ // Accumulator
+
+ asm volatile(
+ // Set registers for destination
+ "movq %[dst_col_stride_q], %%r12\n\t" // stride is r12
+ "shlq $2, %%r12\n\t" // set stride dword
+ "leaq (%%r12,%%r12,0x2), %%r13\n\t" // load stride aligned r13
+
+ // Set accumulators to zero.
+ "vpxor %%ymm4, %%ymm4, %%ymm4 \n\t" // zero accumulators
+ "vpxor %%ymm5, %%ymm5, %%ymm5 \n\t" // zero accumulators
+ "vpxor %%ymm6, %%ymm6, %%ymm6 \n\t" // zero accumulators
+ "vpxor %%ymm7, %%ymm7, %%ymm7 \n\t" // zero accumulators
+ "vpxor %%ymm8, %%ymm8, %%ymm8 \n\t" // zero accumulators
+ "vpxor %%ymm9, %%ymm9, %%ymm9 \n\t" // zero accumulators
+ "vpxor %%ymm10, %%ymm10, %%ymm10\n\t" // zero accumulators
+ "vpxor %%ymm11, %%ymm11, %%ymm11\n\t" // zero accumulators
+ "vpxor %%ymm12, %%ymm12, %%ymm12\n\t" // zero accumulators
+ "vpxor %%ymm13, %%ymm13, %%ymm13\n\t" // zero accumulators
+ "vpxor %%ymm14, %%ymm14, %%ymm14\n\t" // zero accumulators
+ "vpxor %%ymm15, %%ymm15, %%ymm15\n\t" // zero accumulators
+
+ "movq %[run_depth_cells], %%r14 \n\t" // load cell depth r14
+ "subq $2, %%r14 \n\t" // cell depth is 2
+ "js outerLoop1%= \n\t" // outerloop for matrix
+
+ // Loop for K unrolled by 4
+ "outerLoop2%=: \n\t" // outer loop unroll
+
+ // K = 0,1,2,3
+ // RHS cell to ymm1
+
+ // lower half
+ "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
+ "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+ // LHS cell elements 0 and 1
+ "vpmovzxbw 0x00(%[lhs_ptr]), %%ymm0\n\t" // mov lhs to ymm0
+ "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to all ymm2
+ "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to all ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs0 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs1 into ymm3
+ "vpaddd %%ymm2, %%ymm4, %%ymm4 \n\t" // add muladd lhs + rhs0 into ymm4
+ "vpaddd %%ymm3, %%ymm5, %%ymm5 \n\t" // add muladd lhs + rhs1 into ymm5
+ // LHS cell elements 2 and 3
+ "vpshufd $0xaa, %%ymm1, %%ymm2 \n\t" // move rhs 2 element to all ymm2
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rh3 into ymm2
+ "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element into all ymm3
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rh4 into ymm3
+ "vpaddd %%ymm2, %%ymm6, %%ymm6 \n\t" // add muladd lhs + rhs2 into ymm6
+ "vpaddd %%ymm3, %%ymm7, %%ymm7 \n\t" // add muladd lhs + rhs3 into ymm7
+
+ // cache prefetch lhs //see if it works better?
+ //"prefetcht0 0x80(%[lhs_ptr]) \n\t" //prefetch cache lines
+ "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
+ "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+
+ // K = 5,6,7,8
+ // next LHS cell elements 0 and 1
+ "vpmovzxbw 0x10(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
+ "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // mov rhs 0 element to all ymm2
+ "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // mov rhs 1 element to all ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs0 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs1 into ymm3
+ "vpaddd %%ymm2, %%ymm8, %%ymm8 \n\t" // add muladd lhs + rhs0 into ymm8
+ "vpaddd %%ymm3, %%ymm9, %%ymm9 \n\t" // add muladd lhs + rhs1 into ymm9
+ // next LHS cell elements 2 and 3
+ "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to all ymm2
+ "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to all ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs3 into ymm3
+ "vpaddd %%ymm2, %%ymm10, %%ymm10 \n\t" // add muladd lhs + rhs2 into ymm10
+ "vpaddd %%ymm3, %%ymm11, %%ymm11 \n\t" // add muladd lhs + rhs3 into ymm11
+
+ // rhs lower half
+ "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
+ "vpermq $0x44,%%ymm1, %%ymm1 \n\t" // duplcate lower 16
+
+ // next LHS cell elements 0 and 1
+ "vpmovzxbw 0x20(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
+ "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // mov rhs 0 element to all ymm2
+ "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // mov rhs 1 element to all ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs0 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs1 into ymm3
+ "vpaddd %%ymm2, %%ymm12, %%ymm12 \n\t" // add muladd lhs + rhs0 into ymm8
+ "vpaddd %%ymm3, %%ymm13, %%ymm13 \n\t" // add muladd lhs + rhs1 into ymm9
+
+ // cache prefetch rhs //see if it works better?
+ //"prefetcht0 0x80(%[rhs_ptr]) \n\t"
+
+ // next LHS cell elements 2 and 3
+ "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to all ymm2
+ "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to all ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs3 into ymm3
+ "vpaddd %%ymm2, %%ymm14, %%ymm14 \n\t" // add muladd lhs + rhs2 into ymm10
+ "vpaddd %%ymm3, %%ymm15, %%ymm15 \n\t" // add muladd lhs + rhs3 into ymm11
+
+ // current result in ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10 ymm11 ymm12 ymm13 ymm14 ymm15
+
+ // rhs+10 lower half
+ "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
+ "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+ // next LHS cell elements 0 and 1
+ "vpmovzxbw 0x30(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
+ "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to ymm2
+ "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs0 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs1 into ymm3
+ "vpaddd %%ymm2, %%ymm4, %%ymm4 \n\t" // accumulate to ymm4
+ "vpaddd %%ymm3, %%ymm5, %%ymm5 \n\t" // accumulate to ymm5
+ // next LHS cell elements 2 and 3
+ "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to ymm2
+ "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to ymm2
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mull add lhs rhs3 into ymm3
+ "vpaddd %%ymm2, %%ymm6, %%ymm6 \n\t" // add lhs rhs2 to ymm6
+ "vpaddd %%ymm3, %%ymm7, %%ymm7 \n\t" // add lhs rhs3 to ymm7
+
+ // rhs+10 lower half
+ "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
+ "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+
+ // next LHS cell elements 4 and 5
+ "vpmovzxbw 0x40(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
+ "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to ymm2
+ "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs0 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs1 into ymm3
+ "vpaddd %%ymm2, %%ymm8, %%ymm8 \n\t" // accumulate to ymm8
+ "vpaddd %%ymm3, %%ymm9, %%ymm9 \n\t" // accumulate to ymm9
+ // next LHS cell elements 6 and 7
+ "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to ymm2
+ "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to ymm2
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mull add lhs rhs3 into ymm3
+ "vpaddd %%ymm2, %%ymm10, %%ymm10 \n\t" // add lhs rhs2 to ymm10
+ "vpaddd %%ymm3, %%ymm11, %%ymm11 \n\t" // add lhs rhs3 to ymm11
+
+ "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1
+ "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+ // next LHS cell elements 9 and 10
+ "vpmovzxbw 0x50(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0
+ "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to ymm2
+ "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs0 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs1 into ymm3
+ "vpaddd %%ymm2, %%ymm12, %%ymm12 \n\t" // accumulate to ymm12
+ "vpaddd %%ymm3, %%ymm13, %%ymm13 \n\t" // accumulate to ymm13
+
+ // next LHS cell elements 11 and 12
+ "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to ymm2
+ "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to ymm2
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mull add lhs rhs3 into ymm3
+ "vpaddd %%ymm2, %%ymm14, %%ymm14 \n\t" // add lhs rhs2 to ymm14
+ "vpaddd %%ymm3, %%ymm15, %%ymm15 \n\t" // add lhs rhs3 to ymm15
+
+ // completed rhs+10
+ "addq $0x60, %[lhs_ptr] \n\t" // increment stride lhs
+ "addq $0x10, %[rhs_ptr] \n\t" // increment stride rhs
+
+ "subq $2, %[run_depth_cells] \n\t"
+ "ja outerLoop2%= \n\t"
+
+ "movq %[run_depth_cells], %%r14 \n\t"
+ "decq %%r14 \n\t"
+ "js finish%= \n\t"
+
+ // Loop for K unrolled by 2
+ "outerLoop1%=: \n\t"
+
+ // rhs lower
+ "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // get rhs into ymm1
+ "vpermq $0x44,%%ymm1, %%ymm1 \n\t"
+
+ // LHS cell
+ "vpmovzxbw (%[lhs_ptr]), %%ymm0 \n\t" // lhs in into ymm0
+ "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // rhs element 0 into ymm2
+ "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // rhs element 1 into ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 0 ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 1 ymm3
+ "vpaddd %%ymm2, %%ymm4, %%ymm4 \n\t" // acc element 0 ymm4
+ "vpaddd %%ymm3, %%ymm5, %%ymm5 \n\t" // acc element 1 ymm5
+ "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // rhs element 2 into ymm2
+ "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // rhs element 3 into ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 2 ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 3 ymm3
+ "vpaddd %%ymm2, %%ymm6, %%ymm6 \n\t" // acc element 2 into ymm6
+ "vpaddd %%ymm3, %%ymm7, %%ymm7 \n\t" // acc element 3 into ymm7
+
+ // lhs+10
+ "vpmovzxbw 0x10(%[lhs_ptr]), %%ymm0 \n\t" // lhs in into ymm0
+ "vpshufd $0x00, %%ymm1, %%ymm2 \n\t" // rhs element 0 into ymm2
+ "vpshufd $0x55, %%ymm1, %%ymm3 \n\t" // rhs element 1 into ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 0 ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 1 ymm3
+ "vpaddd %%ymm2, %%ymm8, %%ymm8 \n\t" // acc element 0 ymm8
+ "vpaddd %%ymm3, %%ymm9, %%ymm9 \n\t" // acc element 1 ymm9
+ "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // rhs element 2 into ymm2
+ "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // rhs element 3 into ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 2 ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 3 ymm3
+ "vpaddd %%ymm2, %%ymm10, %%ymm10 \n\t" // acc element 2 into ymm10
+ "vpaddd %%ymm3, %%ymm11, %%ymm11 \n\t" // acc element 3 into ymm11
+
+ "vpmovzxbw 0x20(%[lhs_ptr]), %%ymm0 \n\t"
+ "vpshufd $0x00, %%ymm1, %%ymm2 \n\t" // rhs element 0 into ymm2
+ "vpshufd $0x55, %%ymm1, %%ymm3 \n\t" // rhs element 1 into ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 0 ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 1 ymm3
+ "vpaddd %%ymm2, %%ymm12, %%ymm12 \n\t" // acc element 0 ymm12
+ "vpaddd %%ymm3, %%ymm13, %%ymm13 \n\t" // acc element 1 ymm13
+ "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // rhs element 2 into ymm2
+ "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // rhs element 3 into ymm3
+ "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 2 ymm2
+ "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 3 ymm3
+ "vpaddd %%ymm2, %%ymm14, %%ymm14 \n\t" // acc element 2 into ymm14
+ "vpaddd %%ymm3, %%ymm15, %%ymm15 \n\t" // acc element 3 into ymm15
+
+ // update matrix pointers
+ "addq $0x30, %[lhs_ptr] \n\t"
+ "addq $0x08, %[rhs_ptr] \n\t"
+
+ "decq %[run_depth_cells] \n\t"
+ "jnz outerLoop1%= \n\t"
+
+ "finish%=:\n\t"
+
+ "test %[start_depth], %[start_depth] \n\t"
+ "jz storeDst%= \n\t"
+
+ "vpaddd 0x00(%[dst_ptr]), %%ymm4, %%ymm4 \n\t" // rhs0
+ "vpaddd 0x20(%[dst_ptr]), %%ymm8, %%ymm8 \n\t" // rhs0
+ "vpaddd 0x40(%[dst_ptr]), %%ymm12, %%ymm12 \n\t" // rhs0
+
+ "vpaddd 0x00(%[dst_ptr], %%r12, 1) , %%ymm5, %%ymm5 \n\t" // rhs1
+ "vpaddd 0x20(%[dst_ptr], %%r12, 1) , %%ymm9, %%ymm9 \n\t" // rhs1
+ "vpaddd 0x40(%[dst_ptr], %%r12, 1) , %%ymm13, %%ymm13 \n\t" // rhs1
+
+ "vpaddd 0x00(%[dst_ptr], %%r12, 2) , %%ymm6, %%ymm6 \n\t" // rhs2
+ "vpaddd 0x20(%[dst_ptr], %%r12, 2) , %%ymm10, %%ymm10 \n\t" // rhs2
+ "vpaddd 0x40(%[dst_ptr], %%r12, 2) , %%ymm14, %%ymm14 \n\t" // rhs2
+
+ "vpaddd 0x00(%[dst_ptr], %%r13, 1) , %%ymm7, %%ymm7 \n\t" // rhs3
+ "vpaddd 0x20(%[dst_ptr], %%r13, 1) , %%ymm11, %%ymm11 \n\t" // rhs3
+ "vpaddd 0x40(%[dst_ptr], %%r13, 1) , %%ymm15, %%ymm15 \n\t" // rhs3
+
+ "storeDst%=:\n\t"
+
+ "vmovdqu %%ymm4, 0x00(%[dst_ptr]) \n\t" // rhs0
+ "vmovdqu %%ymm8, 0x20(%[dst_ptr]) \n\t" // rhs0
+ "vmovdqu %%ymm12, 0x40(%[dst_ptr]) \n\t" // rhs0
+
+ "vmovdqu %%ymm5, 0x00(%[dst_ptr], %%r12, 1) \n\t" // rhs1
+ "vmovdqu %%ymm9, 0x20(%[dst_ptr], %%r12, 1) \n\t" // rhs1
+ "vmovdqu %%ymm13, 0x40(%[dst_ptr], %%r12, 1) \n\t" // rhs1
+
+ "vmovdqu %%ymm6, 0x00(%[dst_ptr], %%r12, 2) \n\t" // rhs2
+ "vmovdqu %%ymm10, 0x20(%[dst_ptr], %%r12, 2) \n\t" // rhs2
+ "vmovdqu %%ymm14, 0x40(%[dst_ptr], %%r12, 2) \n\t" // rhs2
+
+ "vmovdqu %%ymm7, 0x00(%[dst_ptr], %%r13, 1) \n\t" // rhs3
+ "vmovdqu %%ymm11, 0x20(%[dst_ptr], %%r13, 1) \n\t" // rhs3
+ "vmovdqu %%ymm15, 0x40(%[dst_ptr], %%r13, 1) \n\t" // rhs3
+
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [dst_ptr] "+r"(dst_ptr)
+ : // inputs
+ [start_depth] "r"(start_depth), [dst_col_stride_q] "r"(dst_col_stride_q),
+ [run_depth_cells] "r"(run_depth_cells)
+ : // clobbers
+ "cc", "memory", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7",
+ "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15", "%r12",
+ "%r13", "%r14");
+ }
+};
+#endif
+
+} // namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_KERNEL_AVX_H_
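For orientation, this is the arithmetic that a single depth-2 iteration of the kernel above performs, written out in scalar form: vpmovzxbw widens the uint8 operands to int16, and vpmaddwd multiplies the two depth levels and adds them into the int32 accumulators. The indexing assumes the width-major, depth-interleaved packed layout that the kernel's Format requests; this is an illustrative sketch, not code from gemmlowp:

#include <cstdint>

void Depth2IterationSketch(const std::uint8_t* lhs_block,  // 24 rows x 2 depth, interleaved
                           const std::uint8_t* rhs_block,  // 4 cols x 2 depth, interleaved
                           std::int32_t* acc) {            // 24x4 accumulators, column-major
  constexpr int kLhsWidth = 24, kRhsWidth = 4, kDepth = 2;
  for (int c = 0; c < kRhsWidth; ++c) {
    for (int r = 0; r < kLhsWidth; ++r) {
      std::int32_t sum = 0;
      for (int d = 0; d < kDepth; ++d) {
        sum += static_cast<std::int32_t>(lhs_block[kDepth * r + d]) *
               static_cast<std::int32_t>(rhs_block[kDepth * c + d]);
      }
      acc[c * kLhsWidth + r] += sum;
    }
  }
}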
diff --git a/internal/kernel_default.h b/internal/kernel_default.h
index a919ffe..29b0991 100644
--- a/internal/kernel_default.h
+++ b/internal/kernel_default.h
@@ -20,66 +20,84 @@
#include "../public/bit_depth.h"
#include "common.h"
+#include "kernel.h"
#include "kernel_reference.h"
namespace gemmlowp {
-template <bool MaxProductIsLessThan4096, bool LhsAlwaysNonzero>
+template <bool MaxProductIsLessThan4096, bool IsUnsigned, bool LhsNonZero>
struct DefaultKernelImpl {};
// Partial specialization implementing the logic that if we want to use
-// a kernel for LhsAlwaysNonzero but do not have such a kernel, then we fall
-// back to a generic kernel not taking advantage of LhsAlwaysNonzero.
-template <bool LhsAlwaysNonzero>
-struct DefaultKernelImpl<true, LhsAlwaysNonzero>
- : DefaultKernelImpl<false, LhsAlwaysNonzero> {};
-
-// Partial specialization implementing the logic that if we want to use
// a kernel for MaxProductIsLessThan4096 but do not have such a kernel, then we
// fall back to a generic kernel not taking advantage of
// MaxProductIsLessThan4096.
+template <bool LhsNonZero>
+struct DefaultKernelImpl<true, true, LhsNonZero>
+ : DefaultKernelImpl<false, true, LhsNonZero> {};
+
+// Partial specialization implementing the logic that if we want to use
+// a kernel for LhsNonZero but do not have such a kernel, then we fall
+// back to a generic kernel not taking advantage of LhsNonZero.
template <bool MaxProductIsLessThan4096>
-struct DefaultKernelImpl<MaxProductIsLessThan4096, true>
- : DefaultKernelImpl<MaxProductIsLessThan4096, false> {};
+struct DefaultKernelImpl<MaxProductIsLessThan4096, true, true>
+ : DefaultKernelImpl<MaxProductIsLessThan4096, true, false> {};
template <typename BitDepthParams>
struct DefaultKernel
: DefaultKernelImpl<(BitDepthParams::LhsRange::kMaxValue *
BitDepthParams::RhsRange::kMaxValue <
4096),
- (BitDepthParams::LhsRange::kMinValue > 0)> {};
+ (BitDepthParams::LhsRange::kMinValue >= 0),
+ (BitDepthParams::LhsRange::kMinValue > 0 ||
+ (BitDepthParams::LhsRange::kMaxValue <= 127 &&
+ BitDepthParams::LhsRange::kMinValue > -128))> {};
} // end namespace gemmlowp
-#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \
- LhsAlwaysNonzero, Kernel) \
- namespace gemmlowp { \
- template <> \
- struct DefaultKernelImpl<MaxProductIsLessThan4096, LhsAlwaysNonzero> \
- : Kernel {}; \
+#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, IsUnsigned, \
+ LhsAlwaysNonZero, Kernel) \
+ namespace gemmlowp { \
+ template <> \
+ struct DefaultKernelImpl<MaxProductIsLessThan4096, IsUnsigned, \
+ LhsAlwaysNonZero> : Kernel {}; \
}
+// User-provided int8 inputs are only supported in the NEON path currently.
#if defined GEMMLOWP_NEON_32
#include "kernel_neon.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_32_Kernel12x4Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(true, false,
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_32_Kernel12x4Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(true, true, false,
NEON_32_Kernel12x4Depth2Assuming12BitProducts)
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true,
NEON_32bit_GEMM_Int8Operands_LhsNonzero)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true,
+ NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs)
#elif defined GEMMLOWP_NEON_64
#include "kernel_neon.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2)
-GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
+#if defined GEMMLOWP_DOTPROD_KERNEL
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false,
+ NEON_64_Kernel12x8Depth4_dotprod)
+#else
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_64_Kernel12x8Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true,
NEON_64bit_GEMM_Int8Operands_LhsNonzero)
+#endif
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true,
+ NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs)
#elif defined(GEMMLOWP_MSA)
#include "kernel_msa.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, MSA_Kernel12x8Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, MSA_Kernel12x8Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, MSA_GEMM_Int8Operands_LhsNonzero)
#elif defined GEMMLOWP_SSE4_32
#include "kernel_sse.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_32_Kernel4x4Depth2)
#elif defined GEMMLOWP_SSE4_64
#include "kernel_sse.h"
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_64_Kernel12x4Depth2)
+#elif defined GEMMLOWP_AVX2_64
+#include "kernel_avx.h"
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, AVX2_64_Kernel24x8Depth2)
#else
#include "kernel_reference.h"
namespace gemmlowp {
@@ -88,7 +106,7 @@ typedef ReferenceKernel<KernelFormat<
KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > >
DefaultReferenceKernel;
}
-GEMMLOWP_SET_DEFAULT_KERNEL(false, false, DefaultReferenceKernel)
+GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, DefaultReferenceKernel)
#endif
#endif // GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_
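
The kernel selection above leans on a small template idiom: when no kernel exists for a given capability flag, a partial specialization simply inherits from the less specialized case, so the lookup silently falls through. A self-contained sketch of the same idiom with hypothetical flags (not gemmlowp types):

    #include <cstdio>

    template <bool FastPath, bool LhsNonZero>
    struct KernelChoice {
      static const char* Name() { return "generic kernel"; }
    };

    // No dedicated FastPath kernel: fall back to the non-FastPath case.
    template <bool LhsNonZero>
    struct KernelChoice<true, LhsNonZero> : KernelChoice<false, LhsNonZero> {};

    // A dedicated kernel does exist for (FastPath=false, LhsNonZero=true).
    template <>
    struct KernelChoice<false, true> {
      static const char* Name() { return "LhsNonZero kernel"; }
    };

    int main() {
      // (true, true) falls through to (false, true): prints "LhsNonZero kernel".
      std::printf("%s\n", KernelChoice<true, true>::Name());
      return 0;
    }
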
diff --git a/internal/kernel_msa.h b/internal/kernel_msa.h
index 4985b73..a9205f6 100644
--- a/internal/kernel_msa.h
+++ b/internal/kernel_msa.h
@@ -42,8 +42,8 @@ namespace gemmlowp {
// Our main GEMM kernel.
struct MSA_Kernel12x8Depth2 : KernelBase {
- typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
- KernelSideFormat<CellFormat<4, 2>, 2> >
+ typedef KernelFormat<KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2> >
Format;
const char* Name() const override { return "MSA, 12x8, depth 2"; }
@@ -62,9 +62,6 @@ struct MSA_Kernel12x8Depth2 : KernelBase {
assert(dst_row_stride == 1);
asm volatile(
- // Set a temp to all zeroes.
- "ldi.b $w31, 0\n"
-
// Multiply dst_col_stride by 4 == sizeof(int32) to use
// it as a byte offset below.
GEMMLOWP_MIPS_XSLL
@@ -75,32 +72,25 @@ struct MSA_Kernel12x8Depth2 : KernelBase {
"beqz %[start_depth], " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
// Load accumulators (start_depth != 0).
- GEMMLOWP_MIPS_XADDU
- " $a0, %[dst_ptr], %[dst_col_stride]\n"
+ GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n"
"ld.w $w0, (0*16)(%[dst_ptr])\n"
"ld.w $w4, (1*16)(%[dst_ptr])\n"
- "ld.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
- " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
"ld.w $w1, (0*16)($a0)\n"
"ld.w $w5, (1*16)($a0)\n"
- "ld.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
- " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
"ld.w $w2, (0*16)($a1)\n"
"ld.w $w6, (1*16)($a1)\n"
- "ld.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
- " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
"ld.w $w3, (0*16)($a0)\n"
"ld.w $w7, (1*16)($a0)\n"
- "ld.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
- " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
"ld.w $w12, (0*16)($a1)\n"
"ld.w $w16, (1*16)($a1)\n"
- "ld.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
- " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
"ld.w $w13, (0*16)($a0)\n"
"ld.w $w17, (1*16)($a0)\n"
- "ld.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
- " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
"ld.w $w14, (0*16)($a1)\n"
"ld.w $w18, (1*16)($a1)\n"
"ld.w $w22, (2*16)($a1)\n"
@@ -109,8 +99,7 @@ struct MSA_Kernel12x8Depth2 : KernelBase {
"ld.w $w23, (2*16)($a0)\n"
"b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
- GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
- ":\n"
+ GEMMLOWP_LABEL_CLEAR_ACCUMULATORS ":\n"
// Clear accumulators (start_depth == 0).
"ldi.w $w0, 0\n"
"ldi.w $w4, 0\n"
@@ -139,17 +128,16 @@ struct MSA_Kernel12x8Depth2 : KernelBase {
GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
- GEMMLOWP_LABEL_LOOP
- ":\n"
+ GEMMLOWP_LABEL_LOOP ":\n"
// Overview of register layout:
//
- // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
+ // A half of the 2 2x4 cells of Rhs is stored in 16bit in w28-w31
// (each register contains 4 replicas of a pair of elements).
// A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
// A 12x8 block of accumulators is stored in 32bit in w0-w23.
//
// +------+------+------+------+
- // Rhs |w27 |w28 |w29 |w30 |
+ // Rhs |w28 |w29 |w30 |w31 |
// +------+------+------+------+
//
// | | | | |
@@ -179,128 +167,86 @@ struct MSA_Kernel12x8Depth2 : KernelBase {
"ld.b $w24, 0(%[lhs_ptr])\n"
"ld.b $w25, 8(%[lhs_ptr])\n"
- // Load 4 bytes of rhs[] for the first half of depth 0.
- "lbu $a0, 0(%[rhs_ptr])\n"
- "lbu $a1, 1(%[rhs_ptr])\n"
- "lbu $a2, 2(%[rhs_ptr])\n"
- "lbu $a3, 3(%[rhs_ptr])\n"
- // Load 4 bytes of rhs[] for the first half of depth 1.
- "lbu $v0, 4(%[rhs_ptr])\n"
- "lbu $v1, 5(%[rhs_ptr])\n"
- "lbu $t8, 6(%[rhs_ptr])\n"
- "lbu $t9, 7(%[rhs_ptr])\n"
+ // Load 2 x 8 bytes of rhs[].
+ "ld.b $w27, 0(%[rhs_ptr])\n"
// Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ldi.b $w31, 0\n"
"ilvr.b $w24, $w31, $w24\n"
"ilvl.b $w26, $w31, $w25\n"
"ilvr.b $w25, $w31, $w25\n"
- // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
- "ilvl.d $w27, $w31, $w24\n"
- "ilvl.d $w28, $w31, $w25\n"
- "ilvl.d $w29, $w31, $w26\n"
- "ilvr.h $w24, $w27, $w24\n"
- "ilvr.h $w25, $w28, $w25\n"
- "ilvr.h $w26, $w29, $w26\n"
-
- // Combine and interleave depth 0 and depth 1 elements of rhs[] for
- // dpadd_u.w (for the first half).
- "ins $a0, $v0, 16, 8\n"
- "ins $a1, $v1, 16, 8\n"
- "ins $a2, $t8, 16, 8\n"
- "ins $a3, $t9, 16, 8\n"
- // Make 4 replicas of every pair of rhs[] elements.
- "fill.w $w27, $a0\n"
- "fill.w $w28, $a1\n"
- "fill.w $w29, $a2\n"
- "fill.w $w30, $a3\n"
-
- // Load 4 bytes of rhs[] for the second half of depth 0.
- "lbu $a0, 8(%[rhs_ptr])\n"
- "lbu $a1, 9(%[rhs_ptr])\n"
- "lbu $a2, 10(%[rhs_ptr])\n"
- "lbu $a3, 11(%[rhs_ptr])\n"
- // Load 4 bytes of rhs[] for the second half of depth 1.
- "lbu $v0, 12(%[rhs_ptr])\n"
- "lbu $v1, 13(%[rhs_ptr])\n"
- "lbu $t8, 14(%[rhs_ptr])\n"
- "lbu $t9, 15(%[rhs_ptr])\n"
// First half of depths 0 and 1.
- // Dot-product-(and)-add doubles multiplicand width.
- "dpadd_u.w $w0, $w24, $w27\n"
- "dpadd_u.w $w4, $w25, $w27\n"
- "dpadd_u.w $w8, $w26, $w27\n"
- "dpadd_u.w $w1, $w24, $w28\n"
- "dpadd_u.w $w5, $w25, $w28\n"
- "dpadd_u.w $w9, $w26, $w28\n"
- "dpadd_u.w $w2, $w24, $w29\n"
- "dpadd_u.w $w6, $w25, $w29\n"
- "dpadd_u.w $w10, $w26, $w29\n"
- "dpadd_u.w $w3, $w24, $w30\n"
- "dpadd_u.w $w7, $w25, $w30\n"
- "dpadd_u.w $w11, $w26, $w30\n"
-
- // Combine and interleave depth 0 and depth 1 elements of rhs[] for
- // dpadd_u.w (for the second half).
- "ins $a0, $v0, 16, 8\n"
- "ins $a1, $v1, 16, 8\n"
- "ins $a2, $t8, 16, 8\n"
- "ins $a3, $t9, 16, 8\n"
+ // Zero-extend 8-bit elements of rhs[] to 16 bits.
+ "ilvr.b $w31, $w31, $w27\n"
// Make 4 replicas of every pair of rhs[] elements.
- "fill.w $w27, $a0\n"
- "fill.w $w28, $a1\n"
- "fill.w $w29, $a2\n"
- "fill.w $w30, $a3\n"
+ "splati.w $w28, $w31[0]\n"
+ "splati.w $w29, $w31[1]\n"
+ "splati.w $w30, $w31[2]\n"
+ "splati.w $w31, $w31[3]\n"
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w24, $w28\n"
+ "dpadd_u.w $w4, $w25, $w28\n"
+ "dpadd_u.w $w8, $w26, $w28\n"
+ "dpadd_u.w $w1, $w24, $w29\n"
+ "dpadd_u.w $w5, $w25, $w29\n"
+ "dpadd_u.w $w9, $w26, $w29\n"
+ "dpadd_u.w $w2, $w24, $w30\n"
+ "dpadd_u.w $w6, $w25, $w30\n"
+ "dpadd_u.w $w10, $w26, $w30\n"
+ "dpadd_u.w $w3, $w24, $w31\n"
+ "dpadd_u.w $w7, $w25, $w31\n"
+ "dpadd_u.w $w11, $w26, $w31\n"
// Second half of depths 0 and 1.
+ // Zero-extend 8-bit elements of rhs[] to 16 bits.
+ "ldi.b $w31, 0\n"
+ "ilvl.b $w31, $w31, $w27\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "splati.w $w28, $w31[0]\n"
+ "splati.w $w29, $w31[1]\n"
+ "splati.w $w30, $w31[2]\n"
+ "splati.w $w31, $w31[3]\n"
// Dot-product-(and)-add doubles multiplicand width.
- "dpadd_u.w $w12, $w24, $w27\n"
- "dpadd_u.w $w16, $w25, $w27\n"
- "dpadd_u.w $w20, $w26, $w27\n"
- "dpadd_u.w $w13, $w24, $w28\n"
- "dpadd_u.w $w17, $w25, $w28\n"
- "dpadd_u.w $w21, $w26, $w28\n"
- "dpadd_u.w $w14, $w24, $w29\n"
- "dpadd_u.w $w18, $w25, $w29\n"
- "dpadd_u.w $w22, $w26, $w29\n"
- "dpadd_u.w $w15, $w24, $w30\n"
- "dpadd_u.w $w19, $w25, $w30\n"
- "dpadd_u.w $w23, $w26, $w30\n"
+ "dpadd_u.w $w12, $w24, $w28\n"
+ "dpadd_u.w $w16, $w25, $w28\n"
+ "dpadd_u.w $w20, $w26, $w28\n"
+ "dpadd_u.w $w13, $w24, $w29\n"
+ "dpadd_u.w $w17, $w25, $w29\n"
+ "dpadd_u.w $w21, $w26, $w29\n"
+ "dpadd_u.w $w14, $w24, $w30\n"
+ "dpadd_u.w $w18, $w25, $w30\n"
+ "dpadd_u.w $w22, $w26, $w30\n"
+ "dpadd_u.w $w15, $w24, $w31\n"
+ "dpadd_u.w $w19, $w25, $w31\n"
+ "dpadd_u.w $w23, $w26, $w31\n"
GEMMLOWP_MIPS_XADDIU " %[run_depth], -2\n" GEMMLOWP_MIPS_XADDIU
- " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU
- " %[rhs_ptr], 16\n"
+ " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
"bnez %[run_depth]," GEMMLOWP_LABEL_LOOP "b\n"
GEMMLOWP_LABEL_AFTER_LOOP ":\n"
// Store accumulators.
- GEMMLOWP_MIPS_XADDU
- " $a0, %[dst_ptr], %[dst_col_stride]\n"
+ GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n"
"st.w $w0, (0*16)(%[dst_ptr])\n"
"st.w $w4, (1*16)(%[dst_ptr])\n"
- "st.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
- " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
"st.w $w1, (0*16)($a0)\n"
"st.w $w5, (1*16)($a0)\n"
- "st.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
- " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
"st.w $w2, (0*16)($a1)\n"
"st.w $w6, (1*16)($a1)\n"
- "st.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
- " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
"st.w $w3, (0*16)($a0)\n"
"st.w $w7, (1*16)($a0)\n"
- "st.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
- " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
"st.w $w12, (0*16)($a1)\n"
"st.w $w16, (1*16)($a1)\n"
- "st.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
- " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n"
"st.w $w13, (0*16)($a0)\n"
"st.w $w17, (1*16)($a0)\n"
- "st.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
- " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n"
"st.w $w14, (0*16)($a1)\n"
"st.w $w18, (1*16)($a1)\n"
"st.w $w22, (2*16)($a1)\n"
@@ -308,18 +254,15 @@ struct MSA_Kernel12x8Depth2 : KernelBase {
"st.w $w19, (1*16)($a0)\n"
"st.w $w23, (2*16)($a0)\n"
: // outputs
- [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
- [run_depth] "+r"(run_depth),
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [run_depth] "+r"(run_depth),
[dst_col_stride] "+r"(dst_col_stride)
: // inputs
[dst_ptr] "r"(dst_ptr),
[start_depth] "r"(start_depth)
: // clobbers
- "memory", "v0", "v1", "a0", "a1", "a2", "a3", "t8", "t9", "$f0", "$f1",
- "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", "$f10", "$f11",
- "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
- "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29",
- "$f30", "$f31");
+ "memory", "a0", "a1", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9",
+ "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
+ "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
#undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
#undef GEMMLOWP_LABEL_BEFORE_LOOP
@@ -328,6 +271,303 @@ struct MSA_Kernel12x8Depth2 : KernelBase {
}
};
+// Fast kernel operating on int8 operands.
+// It is assumed that one of the two int8 operands only takes values
+// in [-127, 127], while the other may freely range in [-128, 127].
+// The issue with both operands taking the value -128 is that:
+// -128*-128 + -128*-128 == 32768 overflows int16 (wrapping to -32768).
+// Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16
+// range. That is the basic idea of this kernel.
+struct MSA_GEMM_Int8Operands_LhsNonzero : KernelBase {
+ typedef KernelFormat<
+ KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+ KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
+ Format;
+
+ const char* Name() const override {
+ return "MSA, 4x4, depth 16, accumulating two within signed int16";
+ }
+
+ // TODO(benoitjacob): reorder function arguments so dst comes last
+ void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
+ std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
+ const std::uint8_t* rhs_ptr, std::size_t start_depth,
+ std::size_t run_depth) const override {
+ (void)dst_row_stride;
+#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
+#define GEMMLOWP_LABEL_LOOP "2"
+#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
+#define GEMMLOWP_LABEL_STORE "4"
+ asm volatile(
+ GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
+ // Load lhs[] and rhs[], zero out internal accumulators.
+ "ld.b $w16, 0(%[lhs_ptr])\n"
+ "ldi.b $w0, 0\n"
+ "ld.b $w20, 0(%[rhs_ptr])\n"
+ "ldi.b $w1, 0\n"
+ "ld.b $w17, 16(%[lhs_ptr])\n"
+ "ldi.b $w2, 0\n"
+ "ld.b $w21, 16(%[rhs_ptr])\n"
+ "ldi.b $w3, 0\n"
+ "ld.b $w18, 32(%[lhs_ptr])\n"
+ "ldi.b $w4, 0\n"
+ "ld.b $w19, 48(%[lhs_ptr])\n"
+ "ldi.b $w5, 0\n"
+ "ld.b $w22, 32(%[rhs_ptr])\n"
+ "ldi.b $w6, 0\n"
+ "ld.b $w23, 48(%[rhs_ptr])\n"
+ "ldi.b $w7, 0\n"
+ "ldi.b $w8, 0\n"
+ "ldi.b $w9, 0\n"
+ "ldi.b $w10, 0\n"
+ "ldi.b $w11, 0\n"
+ "ldi.b $w12, 0\n"
+ "ldi.b $w13, 0\n"
+ "ldi.b $w14, 0\n"
+ "ldi.b $w15, 0\n"
+ "ldi.h $w31, 1\n"
+ // If the loop depth is only 16, then we can skip the general loop
+ // and go straight to the final part of the code.
+ "beqz %[run_depth], " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+      // A 4x16 block of Lhs is stored in 8 bit in w16-w19.
+      // A 4x16 block of Rhs is stored in 8 bit in w20-w23.
+ //
+ // A 4x4 block of accumulators is stored in w0-w15 (as 4x32 bit
+ // components which need to be horizontally added at the end).
+ //
+ // Dot products of Lhs and Rhs are 16-bit values, which can't
+ // immediately be accumulated in 32-bit accumulators by that
+ // same instruction that calculates them.
+ // For example, "dotp_s.h $w25, $w16, $w20" produces 8 16-bit
+ // sums in w25 (note, the 16 sums have already been reduced to 8
+ // by the horizontal addition of the dotp instruction).
+ // They are then sign-extended to 32 bits, horizontally added
+ // (again) to form 4 32-bit sums and then they are finally added
+ // to the 32-bit accumulators, all by "dpadd_s.w $w0, $w25, $w31".
+ //
+ // +-----+-----+-----+-----+
+ // Rhs | w20 | w21 | w22 | w23 |
+ // +-----+-----+-----+-----+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w16| | w0 | w4 | w8 | w12 |
+ // |w17| | w1 | w5 | w9 | w13 |
+ // |w18| | w2 | w6 | w10 | w14 |
+ // |w19| | w3 | w7 | w11 | w15 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ //
+ // Accumulators
+
+ // Calculate the results for 16 depths and load
+ // lhs[] and rhs[] for the next iteration.
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 64\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 64\n"
+ GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w25, $w16, $w20\n"
+ "dotp_s.h $w26, $w17, $w20\n"
+ "dotp_s.h $w27, $w16, $w21\n"
+ "dotp_s.h $w28, $w17, $w21\n"
+ "dotp_s.h $w29, $w18, $w20\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w0, $w25, $w31\n"
+ "dpadd_s.w $w1, $w26, $w31\n"
+ "dpadd_s.w $w4, $w27, $w31\n"
+ "dpadd_s.w $w5, $w28, $w31\n"
+ "dpadd_s.w $w2, $w29, $w31\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w24, $w16, $w22\n"
+ "dotp_s.h $w25, $w19, $w20\n"
+ "dotp_s.h $w26, $w16, $w23\n"
+ "dotp_s.h $w27, $w17, $w22\n"
+ "ld.b $w20, 0(%[rhs_ptr])\n"
+ "dotp_s.h $w28, $w17, $w23\n"
+ "ld.b $w16, 0(%[lhs_ptr])\n"
+ "dotp_s.h $w29, $w18, $w21\n"
+ "ld.b $w17, 16(%[lhs_ptr])\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w8, $w24, $w31\n"
+ "dpadd_s.w $w3, $w25, $w31\n"
+ "dpadd_s.w $w12, $w26, $w31\n"
+ "dpadd_s.w $w9, $w27, $w31\n"
+ "dpadd_s.w $w13, $w28, $w31\n"
+ "dpadd_s.w $w6, $w29, $w31\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w25, $w19, $w21\n"
+ "dotp_s.h $w26, $w18, $w22\n"
+ "dotp_s.h $w27, $w18, $w23\n"
+ "ld.b $w21, 16(%[rhs_ptr])\n"
+ "dotp_s.h $w28, $w19, $w22\n"
+ "ld.b $w18, 32(%[lhs_ptr])\n"
+ "dotp_s.h $w29, $w19, $w23\n"
+ "ld.b $w22, 32(%[rhs_ptr])\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w7, $w25, $w31\n"
+ "ld.b $w19, 48(%[lhs_ptr])\n"
+ "dpadd_s.w $w10, $w26, $w31\n"
+ "ld.b $w23, 48(%[rhs_ptr])\n"
+ "dpadd_s.w $w14, $w27, $w31\n"
+ "dpadd_s.w $w11, $w28, $w31\n"
+ "dpadd_s.w $w15, $w29, $w31\n"
+
+ "bnez %[run_depth], " GEMMLOWP_LABEL_LOOP "b\n"
+
+ GEMMLOWP_LABEL_AFTER_LOOP_LAST16 ":\n"
+ // Calculate the results for the last 16 depths.
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w25, $w16, $w20\n"
+ "dotp_s.h $w26, $w17, $w20\n"
+ "dotp_s.h $w27, $w16, $w21\n"
+ "dotp_s.h $w28, $w17, $w21\n"
+ "dotp_s.h $w29, $w18, $w20\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w0, $w25, $w31\n"
+ "dpadd_s.w $w1, $w26, $w31\n"
+ "dpadd_s.w $w4, $w27, $w31\n"
+ "dpadd_s.w $w5, $w28, $w31\n"
+ "dpadd_s.w $w2, $w29, $w31\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w24, $w16, $w22\n"
+ "dotp_s.h $w25, $w19, $w20\n"
+ "dotp_s.h $w26, $w16, $w23\n"
+ "dotp_s.h $w27, $w17, $w22\n"
+ "dotp_s.h $w28, $w17, $w23\n"
+ "dotp_s.h $w29, $w18, $w21\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w8, $w24, $w31\n"
+ "dpadd_s.w $w3, $w25, $w31\n"
+ "dpadd_s.w $w12, $w26, $w31\n"
+ "dpadd_s.w $w9, $w27, $w31\n"
+ "dpadd_s.w $w13, $w28, $w31\n"
+ "dpadd_s.w $w6, $w29, $w31\n"
+
+ // Dot product: multiply-add pairs of adjacent int8 elements.
+ // Each dot product takes 16*2 int8 values in and produces 8 int16 sums.
+ "dotp_s.h $w25, $w19, $w21\n"
+ "dotp_s.h $w26, $w18, $w22\n"
+ "dotp_s.h $w27, $w18, $w23\n"
+ "dotp_s.h $w28, $w19, $w22\n"
+ "dotp_s.h $w29, $w19, $w23\n"
+ // Horizontal add of pairs of adjacent int16 sums into internal int32
+ // accumulators.
+ "dpadd_s.w $w7, $w25, $w31\n"
+ "dpadd_s.w $w10, $w26, $w31\n"
+ "dpadd_s.w $w14, $w27, $w31\n"
+ "dpadd_s.w $w11, $w28, $w31\n"
+ "dpadd_s.w $w15, $w29, $w31\n"
+
+ // Horizontal-add internal accumulators.
+ "hadd_s.d $w0, $w0, $w0\n"
+ "hadd_s.d $w1, $w1, $w1\n"
+ "hadd_s.d $w2, $w2, $w2\n"
+ "hadd_s.d $w3, $w3, $w3\n"
+ "hadd_s.d $w4, $w4, $w4\n"
+ "hadd_s.d $w5, $w5, $w5\n"
+ "hadd_s.d $w6, $w6, $w6\n"
+ "hadd_s.d $w7, $w7, $w7\n"
+ "hadd_s.d $w8, $w8, $w8\n"
+ "hadd_s.d $w9, $w9, $w9\n"
+ "hadd_s.d $w10, $w10, $w10\n"
+ "hadd_s.d $w11, $w11, $w11\n"
+ "hadd_s.d $w12, $w12, $w12\n"
+ "hadd_s.d $w13, $w13, $w13\n"
+ "hadd_s.d $w14, $w14, $w14\n"
+ "hadd_s.d $w15, $w15, $w15\n"
+ "pckev.w $w0, $w1, $w0\n"
+ "pckev.w $w2, $w3, $w2\n"
+ "pckev.w $w4, $w5, $w4\n"
+ "pckev.w $w6, $w7, $w6\n"
+ "pckev.w $w8, $w9, $w8\n"
+ "pckev.w $w10, $w11, $w10\n"
+ "pckev.w $w12, $w13, $w12\n"
+ "pckev.w $w14, $w15, $w14\n"
+ "hadd_s.d $w0, $w0, $w0\n"
+ "hadd_s.d $w2, $w2, $w2\n"
+ "hadd_s.d $w4, $w4, $w4\n"
+ "hadd_s.d $w6, $w6, $w6\n"
+ "hadd_s.d $w8, $w8, $w8\n"
+ "hadd_s.d $w10, $w10, $w10\n"
+ "hadd_s.d $w12, $w12, $w12\n"
+ "hadd_s.d $w14, $w14, $w14\n"
+ // 4 more pckev instructions follow in both paths below.
+
+ // Check if start_depth==0 to decide whether we will load
+ // existing accumulators from memory.
+ "bnez %[start_depth], " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "f\n"
+
+ "pckev.w $w0, $w2, $w0\n"
+ "pckev.w $w1, $w6, $w4\n"
+ "pckev.w $w2, $w10, $w8\n"
+ "pckev.w $w3, $w14, $w12\n"
+
+ "b " GEMMLOWP_LABEL_STORE "f\n"
+
+ GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES ":\n"
+ // Load accumulators from memory.
+ "ld.w $w16, 0(%[dst_ptr0])\n"
+ "pckev.w $w0, $w2, $w0\n"
+ "ld.w $w17, 0(%[dst_ptr1])\n"
+ "pckev.w $w1, $w6, $w4\n"
+ "ld.w $w18, 0(%[dst_ptr2])\n"
+ "pckev.w $w2, $w10, $w8\n"
+ "ld.w $w19, 0(%[dst_ptr3])\n"
+ "pckev.w $w3, $w14, $w12\n"
+
+ // Add them to internal accumulators.
+ "addv.w $w0, $w0, $w16\n"
+ "addv.w $w1, $w1, $w17\n"
+ "addv.w $w2, $w2, $w18\n"
+ "addv.w $w3, $w3, $w19\n"
+
+ GEMMLOWP_LABEL_STORE ":\n"
+ // Store accumulators.
+ "st.w $w0, 0(%[dst_ptr0])\n"
+ "st.w $w1, 0(%[dst_ptr1])\n"
+ "st.w $w2, 0(%[dst_ptr2])\n"
+ "st.w $w3, 0(%[dst_ptr3])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [run_depth] "+r"(run_depth)
+ : // inputs
+ [dst_ptr0] "r"(dst_ptr), [dst_ptr1] "r"(dst_ptr + dst_col_stride),
+ [dst_ptr2] "r"(dst_ptr + dst_col_stride * 2),
+ [dst_ptr3] "r"(dst_ptr + dst_col_stride * 3),
+ [start_depth] "r"(start_depth)
+ : // clobbers
+ "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8",
+ "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17",
+ "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26",
+ "$f27", "$f28", "$f29", "$f30", "$f31");
+#undef GEMMLOWP_LABEL_LOOP
+#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16
+#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES
+#undef GEMMLOWP_LABEL_STORE
+ }
+};
+
#undef GEMMLOWP_MIPS_XADDU
#undef GEMMLOWP_MIPS_XADDIU
#undef GEMMLOWP_MIPS_XSLL
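
The range argument behind MSA_GEMM_Int8Operands_LhsNonzero (and its NEON counterparts) is easy to verify by brute force: with one operand confined to [-127, 127], any a*b + c*d of int8 values stays inside int16. A small standalone check:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Extremes of a single product with one factor in [-127, 127] and the
      // other in [-128, 127].
      std::int32_t pmax = 0, pmin = 0;
      for (int a = -127; a <= 127; a++) {
        for (int b = -128; b <= 127; b++) {
          const std::int32_t p = a * b;
          if (p > pmax) pmax = p;
          if (p < pmin) pmin = p;
        }
      }
      // A pairwise sum a*b + c*d is bounded by twice those extremes.
      assert(2 * pmax <= INT16_MAX);  // 32512 <= 32767
      assert(2 * pmin >= INT16_MIN);  // -32512 >= -32768
      return 0;
    }
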
diff --git a/internal/kernel_neon.h b/internal/kernel_neon.h
index 3cd48f4..9859637 100644
--- a/internal/kernel_neon.h
+++ b/internal/kernel_neon.h
@@ -55,6 +55,7 @@ struct NEON_32_Kernel12x4Depth2 : KernelBase {
#define GEMMLOWP_LABEL_AFTER_LOOP "4"
assert(dst_row_stride == 1);
+ (void)dst_row_stride;
asm volatile(
// Overview of register layout:
//
@@ -308,6 +309,7 @@ struct NEON_32_Kernel12x4Depth2Assuming12BitProducts : KernelBase {
ScopedProfilingLabel label(
"optimized kernel (NEON 12x4, assuming 12-bit products)");
assert(dst_row_stride == 1);
+ (void)dst_row_stride;
// See comments above for why we need local numerical labels in our asm.
#define GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS "1"
@@ -678,6 +680,7 @@ struct NEON_32bit_GEMM_Int8Operands_LhsNonzero : KernelBase {
std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
const std::uint8_t* rhs_ptr, std::size_t start_depth,
std::size_t run_depth) const override {
+ (void)dst_row_stride;
#define GEMMLOWP_LABEL_AFTER_LOOP "1"
#define GEMMLOWP_LABEL_LOOP "2"
#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
@@ -921,6 +924,17 @@ struct NEON_32bit_GEMM_Int8Operands_LhsNonzero : KernelBase {
}
};
+// Same as NEON_32bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that
+// requires that user inputs were originally int8. This avoids the uint8->int8
+// conversion in the pack step.
+struct NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs
+ : NEON_32bit_GEMM_Int8Operands_LhsNonzero {
+ typedef KernelFormat<
+ KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+ KernelSideFormatInt8Inputs<CellFormat<2, 16, CellOrder::WidthMajor>, 1> >
+ Format;
+};
+
#endif // GEMMLOWP_NEON_32
// The kernels here are specifically arm 64bit assembly, not arm 32bit.
@@ -940,6 +954,7 @@ struct NEON_64bit_GEMM_Int8Operands_LhsNonzero : KernelBase {
std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
const std::uint8_t* rhs_ptr, std::size_t start_depth,
std::size_t run_depth) const override {
+ (void)dst_row_stride;
#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1"
#define GEMMLOWP_LABEL_LOOP "2"
#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3"
@@ -1261,6 +1276,17 @@ struct NEON_64bit_GEMM_Int8Operands_LhsNonzero : KernelBase {
}
};
+// Same as NEON_64bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that
+// requires that user inputs were originally int8. This avoids the uint8->int8
+// conversion in the pack step.
+struct NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs
+ : NEON_64bit_GEMM_Int8Operands_LhsNonzero {
+ typedef KernelFormat<
+ KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>,
+ KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1> >
+ Format;
+};
+
// Our main GEMM kernel.
struct NEON_64_Kernel12x8Depth2 : KernelBase {
typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
@@ -1274,6 +1300,7 @@ struct NEON_64_Kernel12x8Depth2 : KernelBase {
std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
const std::uint8_t* rhs_ptr, std::size_t start_depth,
std::size_t run_depth) const override {
+ (void)dst_row_stride;
ScopedProfilingLabel label("optimized kernel (NEON 12x8)");
// See comments above for why we need local numerical labels in our asm.
#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
@@ -1611,6 +1638,274 @@ struct NEON_64_Kernel12x8Depth2 : KernelBase {
}
};
+#ifdef GEMMLOWP_DOTPROD_KERNEL
+#ifndef __ARM_FEATURE_DOTPROD
+#error This kernel requires ARM dot-product instructions. Enable them by \
+ adding '+dotprod' to a compiler flag, e.g. -march=armv8.2-a+dotprod . \
+ Note that Clang up to version 7 fails to define the corresponding \
+ preprocessor token __ARM_FEATURE_DOTPROD, so you will still have to define \
+ it manually.
+#endif
+// Kernels utilizing the Armv8.2 Dot Product extension.
+//
+// The dot product instructions work by taking 4 consecutive 8-bit depth
+// values from each operand, multiplying the 4 pairs together and
+// accumulating all the results into the corresponding 32-bit accumulator
+// lane. As such, the operation is identical to a 32-bit instruction (like
+// FMLA used in SGEMM), except that 4 depth values are processed at a time
+// instead of 1.
+
+// Thus, this first kernel is a carbon copy of
+// "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good
+// performance for most processors) below with the opcode (fmla -> udot) and
+// types (float32 -> uint8/uint32) changed.
+//
+// A signed version of this kernel could be produced by replacing "udot"
+// with "sdot" - performance should be identical to this udot kernel.
+struct NEON_64_Kernel12x8Depth4_dotprod : KernelBase {
+ typedef KernelFormat<KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
+ Format;
+
+ const char* Name() const override { return "NEON, 12x8, depth 4, dotprod"; }
+
+ void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, std::size_t dst_col_stride,
+ const std::uint8_t* lhs_ptr, const std::uint8_t* rhs_ptr, std::size_t start_depth,
+ std::size_t depth) const override {
+ (void)dst_row_stride;
+ ScopedProfilingLabel label("optimized kernel (NEON 12x8, depth 4, dotprod)");
+// See comments above for why we need local numerical labels in our asm.
+#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
+#define GEMMLOWP_LABEL_BEFORE_LOOP "2"
+#define GEMMLOWP_LABEL_LOOP "3"
+#define GEMMLOWP_LABEL_AFTER_LOOP "4"
+
+ assert(dst_row_stride == 1);
+ asm volatile(
+ // Multiply dst_col_stride by 4 == sizeof(int32) to use
+ // it as a byte offset below.
+ "lsl %[dst_col_stride], %[dst_col_stride], #2\n"
+
+ "cmp %[start_depth], #0\n"
+ "beq " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
+
+ // Load accumulators
+ "mov x1, %[dst_ptr]\n"
+ "mov x0, x1\n"
+ "ld1 {v8.16b}, [x0], #16\n"
+ "ld1 {v16.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "ld1 {v24.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "ld1 {v9.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "ld1 {v17.16b}, [x0], #16\n"
+ "ld1 {v25.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "ld1 {v10.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "ld1 {v18.16b}, [x0], #16\n"
+ "ld1 {v26.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "ld1 {v11.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "ld1 {v19.16b}, [x0], #16\n"
+ "ld1 {v27.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "ld1 {v12.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "ld1 {v20.16b}, [x0], #16\n"
+ "ld1 {v28.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "ld1 {v13.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "ld1 {v21.16b}, [x0], #16\n"
+ "ld1 {v29.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "ld1 {v14.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "ld1 {v22.16b}, [x0], #16\n"
+ "ld1 {v30.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "ld1 {v15.16b}, [x0], #16\n"
+ "ld1 {v23.16b}, [x0], #16\n"
+ "ld1 {v31.16b}, [x0]\n"
+
+ "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
+
+ GEMMLOWP_LABEL_CLEAR_ACCUMULATORS ":\n"
+
+ // Clear accumulator registers (see layout below)
+ "dup v8.4s, wzr\n"
+ "dup v9.4s, wzr\n"
+ "dup v10.4s, wzr\n"
+ "dup v11.4s, wzr\n"
+ "dup v12.4s, wzr\n"
+ "dup v13.4s, wzr\n"
+ "dup v14.4s, wzr\n"
+ "dup v15.4s, wzr\n"
+ "dup v16.4s, wzr\n"
+ "dup v17.4s, wzr\n"
+ "dup v18.4s, wzr\n"
+ "dup v19.4s, wzr\n"
+ "dup v20.4s, wzr\n"
+ "dup v21.4s, wzr\n"
+ "dup v22.4s, wzr\n"
+ "dup v23.4s, wzr\n"
+ "dup v24.4s, wzr\n"
+ "dup v25.4s, wzr\n"
+ "dup v26.4s, wzr\n"
+ "dup v27.4s, wzr\n"
+ "dup v28.4s, wzr\n"
+ "dup v29.4s, wzr\n"
+ "dup v30.4s, wzr\n"
+ "dup v31.4s, wzr\n"
+
+ GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
+
+ "subs %w[depth], %w[depth], #4\n"
+
+ // The start of the loop assumes first Rhs cell is already loaded, so
+ // do it here for first iteration.
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+
+ // And the same for the first Lhs cell.
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+
+ "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+
+ // Start the MACs at the head of the loop - 1st cell from each side
+ // already loaded.
+ ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n"
+ ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n"
+ "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
+ ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n"
+ ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
+ ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n"
+ ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
+ ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n"
+ ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load
+ // for the next iteration early.
+ ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n"
+ ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n"
+ ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n"
+ ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n"
+ ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n"
+ ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n"
+ ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n"
+ ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n"
+ ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n"
+ ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n"
+ ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n"
+ ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n"
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell -
+ // load for the next iteration early.
+ ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n"
+ ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n"
+
+ // Loop. Decrement loop index (depth) by 4 as udot processes 4
+ // depth values.
+ "subs %w[depth], %w[depth], #4\n"
+ ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n"
+ ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n"
+
+ "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+ GEMMLOWP_LABEL_AFTER_LOOP ":\n"
+
+ // Final iteration. v0 and v2 were already loaded, don't load
+ // them again, don't read past the end of buffers.
+ ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n"
+ ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n"
+ "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
+ ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n"
+ ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
+ ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n"
+ ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
+ ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n"
+ ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n"
+ ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n"
+ ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n"
+ ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n"
+ ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n"
+ ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n"
+ ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n"
+ ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n"
+ ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n"
+ ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n"
+ ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n"
+ ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n"
+ ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n"
+ ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n"
+ ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n"
+
+ // Loop. Decrement loop index (depth) by 4 as udot processes 4
+ // depth values.
+ "subs %w[depth], %w[depth], #4\n"
+ ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n"
+ ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n"
+
+ // Store accumulators
+ "mov x1, %[dst_ptr]\n"
+ "mov x0, x1\n"
+ "st1 {v8.16b}, [x0], #16\n"
+ "st1 {v16.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "st1 {v24.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "st1 {v9.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "st1 {v17.16b}, [x0], #16\n"
+ "st1 {v25.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "st1 {v10.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "st1 {v18.16b}, [x0], #16\n"
+ "st1 {v26.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "st1 {v11.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "st1 {v19.16b}, [x0], #16\n"
+ "st1 {v27.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "st1 {v12.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "st1 {v20.16b}, [x0], #16\n"
+ "st1 {v28.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "st1 {v13.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "st1 {v21.16b}, [x0], #16\n"
+ "st1 {v29.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "st1 {v14.16b}, [x0], #16\n"
+ "add x1, x1, %[dst_col_stride]\n"
+ "st1 {v22.16b}, [x0], #16\n"
+ "st1 {v30.16b}, [x0]\n"
+ "mov x0, x1\n"
+ "st1 {v15.16b}, [x0], #16\n"
+ "st1 {v23.16b}, [x0], #16\n"
+ "st1 {v31.16b}, [x0]\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [dst_ptr] "r"(dst_ptr), [dst_col_stride] "r"(dst_col_stride), [start_depth] "r"(start_depth)
+ : // clobbers
+ "cc", "memory", "x0", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+ "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
+ "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
+ }
+};
+#endif // GEMMLOWP_DOTPROD_KERNEL
+
#endif // GEMMLOWP_NEON_64
} // namespace gemmlowp
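
For reference, each 32-bit lane updated by the udot instructions in the dotprod kernel above behaves like the following scalar model: four consecutive uint8 depth values from each side are multiplied pairwise and their sum is accumulated into the lane. Minimal sketch, illustrative names only:

    #include <cstdint>

    inline std::int32_t UdotLane(std::int32_t acc, const std::uint8_t lhs[4],
                                 const std::uint8_t rhs[4]) {
      for (int i = 0; i < 4; i++) {
        acc += static_cast<std::int32_t>(lhs[i]) *
               static_cast<std::int32_t>(rhs[i]);
      }
      return acc;
    }
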
diff --git a/internal/kernel_sse.h b/internal/kernel_sse.h
index b879fd7..ba7959b 100644
--- a/internal/kernel_sse.h
+++ b/internal/kernel_sse.h
@@ -43,6 +43,7 @@ struct SSE4_32_Kernel4x4Depth2 : KernelBase {
std::size_t run_depth) const override {
ScopedProfilingLabel label("optimized kernel");
assert(dst_row_stride == 1);
+ (void)dst_row_stride;
std::int32_t run_depth_cells = run_depth / Format::kDepth;
/* Main loop */
@@ -217,6 +218,7 @@ struct SSE4_64_Kernel12x4Depth2 : KernelBase {
std::size_t run_depth) const override {
ScopedProfilingLabel label("optimized kernel");
assert(dst_row_stride == 1);
+ (void)dst_row_stride;
const std::int64_t run_depth_cells = run_depth / Format::kDepth;
const std::int64_t dst_col_stride_q = dst_col_stride;
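
The (void)dst_row_stride; lines added across these kernels address a warning rather than behavior: in NDEBUG builds the assert compiles away and the parameter would otherwise be unused. A minimal illustration of the pattern:

    #include <cassert>

    void Run(int dst_row_stride) {
      assert(dst_row_stride == 1);  // compiled out when NDEBUG is defined
      (void)dst_row_stride;         // keeps the parameter "used" in release builds
    }
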
diff --git a/internal/multi_thread_gemm.h b/internal/multi_thread_gemm.h
index 791402f..97183e7 100644
--- a/internal/multi_thread_gemm.h
+++ b/internal/multi_thread_gemm.h
@@ -19,23 +19,43 @@
#ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
#define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
+#include <atomic> // NOLINT
+#include <chrono> // NOLINT
+#include <thread> // NOLINT
#include <vector>
#include "single_thread_gemm.h"
namespace gemmlowp {
-// On X86 and ARM platforms we enable a busy-wait spinlock before waiting on a
-// pthread conditional variable. In order to implement that correctly we need
-// to put some explicit memory load/store barriers.
+// This value was empirically derived on an end-to-end application benchmark.
+// That this number of cycles means that we may be sleeping substantially longer
+// than a scheduler timeslice's duration is not necessarily surprising. The
+// idea is to quickly pick up new work after having finished the previous
+// workload. When it's new work within the same GEMM as the previous work, the
+// time interval that we might be busy-waiting is very small, so for that
+// purpose it would be more than enough to sleep for 1 million cycles.
+// That is all we would observe on a GEMM benchmark. However, in a real
+// application, after having finished a GEMM, we might do unrelated work for
+// a little while, then start on a new GEMM. Think of a neural network
+// application performing inference, where many but not all layers are
+// implemented by a GEMM. In such cases, our worker threads might be idle for
+// longer periods of time before having work again. If we let them passively
+// wait, on a mobile device, the CPU scheduler might aggressively clock down
+// or even turn off the CPU cores that they were running on. That would result
+// in a long delay the next time these need to be turned back on for the next
+// GEMM. So we need to strike a balance that reflects typical time intervals
+// between consecutive GEMM invocations, not just intra-GEMM considerations.
+// Of course, we need to balance keeping CPUs spinning longer to resume work
+// faster, versus passively waiting to conserve power.
+const int kMaxBusyWaitNOPs = 4 * 1000 * 1000;
+
+// On X86 and ARM platforms we may use NOP instructions to know how long we
+// are busy-waiting.
#if defined(GEMMLOWP_ALLOW_INLINE_ASM) && !defined(GEMMLOWP_NO_BUSYWAIT) && \
(defined(GEMMLOWP_ARM) || defined(GEMMLOWP_X86))
-#define GEMMLOWP_USE_BUSYWAIT
-
-const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;
-
#define GEMMLOWP_NOP "nop\n"
#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
@@ -43,46 +63,26 @@ const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)
-inline int Do256NOPs() {
+inline int DoSomeNOPs() {
asm volatile(GEMMLOWP_NOP64);
return 64;
}
#undef GEMMLOWP_STRING_CONCAT_4
-#undef GEMMLOWP_NOP256
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP
-inline void WriteBarrier() {
-#if defined(_MSC_VER)
- MemoryBarrier();
-#elif defined(GEMMLOWP_ARM_32)
- asm volatile("" ::: "memory");
-#elif defined(GEMMLOWP_ARM_64)
- asm volatile("dmb ishst" ::: "memory");
-#elif defined(GEMMLOWP_X86)
- asm volatile("sfence" ::: "memory");
-#else
-#error "Unsupported architecture for WriteBarrier."
-#endif
-}
+#else // May not use asm NOP.
-inline void ReadBarrier() {
-#if defined(_MSC_VER)
- MemoryBarrier();
-#elif defined(GEMMLOWP_ARM_32)
- asm volatile("" ::: "memory");
-#elif defined(GEMMLOWP_ARM_64)
- asm volatile("dmb ishld" ::: "memory");
-#elif defined(GEMMLOWP_X86)
- asm volatile("lfence" ::: "memory");
-#else
-#error "Unsupported architecture for ReadBarrier."
-#endif
+// If we can't use NOPs, let's use a non-inline function call as a basic
+// thing that has some vaguely known, nonzero cost.
+GEMMLOWP_NOINLINE
+inline int DoSomeNOPs() {
+ // Pretend that calling an empty function takes as long as 16 NOPs...
+ return 16;
}
-
#endif
// Waits until *var != initial_value.
@@ -108,37 +108,29 @@ inline void ReadBarrier() {
// so as to avoid permanently spinning.
//
template <typename T>
-T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
- pthread_mutex_t* mutex) {
-#ifdef GEMMLOWP_USE_BUSYWAIT
- // If we are on a platform that supports it, spin for some time.
- {
- int nops = 0;
- // First, trivial case where the variable already changed value.
- T new_value = *var;
+T WaitForVariableChange(std::atomic<T>* var, T initial_value,
+ pthread_cond_t* cond, pthread_mutex_t* mutex) {
+ // First, trivial case where the variable already changed value.
+ T new_value = var->load(std::memory_order_acquire);
+ if (new_value != initial_value) {
+ return new_value;
+ }
+ // Then try busy-waiting.
+ int nops = 0;
+ while (nops < kMaxBusyWaitNOPs) {
+ nops += DoSomeNOPs();
+ new_value = var->load(std::memory_order_acquire);
if (new_value != initial_value) {
- ReadBarrier();
return new_value;
}
- // Then try busy-waiting.
- while (nops < kMaxBusyWaitNOPs) {
- nops += Do256NOPs();
- new_value = *var;
- if (new_value != initial_value) {
- ReadBarrier();
- return new_value;
- }
- }
}
-#endif
// Finally, do real passive waiting.
pthread_mutex_lock(mutex);
- T new_value = *var;
- if (new_value == initial_value) {
+ new_value = var->load(std::memory_order_acquire);
+ while (new_value == initial_value) {
pthread_cond_wait(cond, mutex);
- new_value = *var;
- assert(new_value != initial_value);
+ new_value = var->load(std::memory_order_acquire);
}
pthread_mutex_unlock(mutex);
return new_value;
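
The waiting pattern above only works if the signaling side publishes the new value before waking waiters. A sketch of a matching signaling helper (hypothetical free function, not gemmlowp API), mirroring what Worker::ChangeState does further down: store with release semantics while holding the same mutex, then broadcast the condition variable:

    #include <pthread.h>
    #include <atomic>

    template <typename T>
    void SignalVariableChange(std::atomic<T>* var, T new_value,
                              pthread_cond_t* cond, pthread_mutex_t* mutex) {
      pthread_mutex_lock(mutex);
      var->store(new_value, std::memory_order_release);
      pthread_cond_broadcast(cond);
      pthread_mutex_unlock(mutex);
    }
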
@@ -147,73 +139,74 @@ T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
+// The waiting is done using a naive spinlock waiting for the atomic
+// count_ to hit the value 0. This is acceptable because in our usage
+// pattern, BlockingCounter is used only to synchronize threads after
+// short-lived tasks (performing parts of the same GEMM). It is not used
+// for synchronizing longer waits (resuming work on the next GEMM).
class BlockingCounter {
public:
- BlockingCounter() : count_(0), initial_count_(0) {
- pthread_cond_init(&cond_, nullptr);
- pthread_mutex_init(&mutex_, nullptr);
- }
-
- ~BlockingCounter() {
- pthread_cond_destroy(&cond_);
- pthread_mutex_destroy(&mutex_);
- }
+ BlockingCounter() : count_(0) {}
// Sets/resets the counter; initial_count is the number of
// decrementing events that the Wait() call will be waiting for.
void Reset(std::size_t initial_count) {
- pthread_mutex_lock(&mutex_);
- assert(count_ == 0);
- initial_count_ = initial_count;
- count_ = initial_count_;
- pthread_mutex_unlock(&mutex_);
+ std::size_t old_count_value = count_.load(std::memory_order_relaxed);
+ assert(old_count_value == 0);
+ (void)old_count_value;
+ count_.store(initial_count, std::memory_order_release);
}
// Decrements the counter; if the counter hits zero, signals
- // the thread that was waiting for that, and returns true.
+ // the threads that were waiting for that, and returns true.
// Otherwise (if the decremented count is still nonzero),
// returns false.
bool DecrementCount() {
- pthread_mutex_lock(&mutex_);
- assert(count_ > 0);
- count_--;
-#ifdef GEMMLOWP_USE_BUSYWAIT
- WriteBarrier();
-#endif
- if (count_ == 0) {
- pthread_cond_signal(&cond_);
- }
- bool retval = count_ == 0;
- pthread_mutex_unlock(&mutex_);
- return retval;
+ std::size_t old_count_value =
+ count_.fetch_sub(1, std::memory_order_acq_rel);
+ assert(old_count_value > 0);
+ std::size_t count_value = old_count_value - 1;
+ return count_value == 0;
}
// Waits for the N other threads (N having been set by Reset())
// to hit the BlockingCounter.
void Wait() {
ScopedProfilingLabel label("BlockingCounter::Wait");
- while (count_) {
-#ifdef GEMMLOWP_USE_BUSYWAIT
- ReadBarrier();
-#else
- // This is likely unnecessary, but is kept to ensure regressions are not
- // introduced.
-#ifndef _WIN32
- asm volatile("" ::: "memory");
-#endif
-#endif
- const std::size_t count_value = count_;
- if (count_value) {
- WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
+ // Busy-wait until the count value is 0.
+ int nops = 0;
+ while (count_.load(std::memory_order_acquire)) {
+ nops += DoSomeNOPs();
+ if (nops > kMaxBusyWaitNOPs) {
+ nops = 0;
+ // If we are unlucky, the blocking thread (that calls DecrementCount)
+ // and the blocked thread (here, calling Wait) may be scheduled on
+ // the same CPU, so the busy-waiting of the present thread may prevent
+ // the blocking thread from resuming and unblocking.
+      // If we are even unluckier, the priority of the present thread
+ // might be higher than that of the blocking thread, so just yielding
+ // wouldn't allow the blocking thread to resume. So we sleep for
+ // a substantial amount of time in that case. Notice that we only
+ // do so after having busy-waited for kMaxBusyWaitNOPs, which is
+ // typically several milliseconds, so sleeping 1 more millisecond
+ // isn't terrible at that point.
+ //
+ // How this is mitigated in practice:
+ // In practice, it is well known that the application should be
+ // conservative in choosing how many threads to tell gemmlowp to use,
+ // as it's hard to know how many CPU cores it will get to run on,
+ // on typical mobile devices.
+ // It seems impossible for gemmlowp to make this choice automatically,
+ // which is why gemmlowp's default is to use only 1 thread, and
+ // applications may override that if they know that they can count on
+ // using more than that.
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
}
private:
- pthread_cond_t cond_;
- pthread_mutex_t mutex_;
- std::size_t count_;
- std::size_t initial_count_;
+ std::atomic<std::size_t> count_;
};
// A workload for a worker.
@@ -253,11 +246,15 @@ class Worker {
// Changes State; may be called from either the worker thread
// or the master thread; however, not all state transitions are legal,
// which is guarded by assertions.
- void ChangeState(State new_state) {
+ //
+ // The Task argument is to be used only with new_state==HasWork.
+ // It specifies the Task being handed to this Worker.
+ void ChangeState(State new_state, Task* task = nullptr) {
ScopedProfilingLabel label("Worker::ChangeState");
pthread_mutex_lock(&state_mutex_);
- assert(new_state != state_);
- switch (state_) {
+ State old_state = state_.load(std::memory_order_relaxed);
+ assert(old_state != new_state);
+ switch (old_state) {
case State::ThreadStartup:
assert(new_state == State::Ready);
break;
@@ -272,18 +269,33 @@ class Worker {
default:
abort();
}
- state_ = new_state;
- pthread_cond_signal(&state_cond_);
- if (state_ == State::Ready) {
- counter_to_decrement_when_ready_->DecrementCount();
+ switch (new_state) {
+ case State::Ready:
+ if (task_) {
+ // Doing work is part of reverting to 'ready' state.
+ task_->Run();
+ task_ = nullptr;
+ }
+ break;
+ case State::HasWork:
+ assert(!task_);
+ task->local_allocator = &local_allocator_;
+ task_ = task;
+ break;
+ default:
+ break;
}
+ state_.store(new_state, std::memory_order_relaxed);
+ pthread_cond_broadcast(&state_cond_);
pthread_mutex_unlock(&state_mutex_);
+ if (new_state == State::Ready) {
+ counter_to_decrement_when_ready_->DecrementCount();
+ }
}
// Thread entry point.
void ThreadFunc() {
ScopedProfilingLabel label("Worker::ThreadFunc");
- RegisterCurrentThreadForProfiling();
ChangeState(State::Ready);
@@ -299,9 +311,6 @@ class Worker {
switch (state_to_act_upon) {
case State::HasWork:
// Got work to do! So do it, and then revert to 'Ready' state.
- assert(task_);
- task_->Run();
- task_ = nullptr;
ChangeState(State::Ready);
break;
case State::ExitAsSoonAsPossible:
@@ -318,17 +327,7 @@ class Worker {
}
  // Called by the master thread to give this worker work to do.
- // It is only legal to call this if the worker
- void StartWork(Task* task) {
- assert(!task_);
- task->local_allocator = &local_allocator_;
- task_ = task;
-#ifdef GEMMLOWP_USE_BUSYWAIT
- WriteBarrier();
-#endif
- assert(state_ == State::Ready);
- ChangeState(State::HasWork);
- }
+ void StartWork(Task* task) { ChangeState(State::HasWork, task); }
private:
// The underlying thread.
@@ -342,7 +341,10 @@ class Worker {
pthread_mutex_t state_mutex_;
// The state enum tells if we're currently working, waiting for work, etc.
- State state_;
+ // Its concurrent accesses by the worker and main threads are guarded by
+ // state_mutex_, and can thus use memory_order_relaxed. This still needs
+ // to be a std::atomic because we use WaitForVariableChange.
+ std::atomic<State> state_;
// Each thread had a local allocator so they can allocate temporary
// buffers without blocking each other.
@@ -359,9 +361,7 @@ class Worker {
// waits for all of them to finish.
//
// See MultiThreadGemmContextBase for how other WorkersPool implementations can
-// be used. Note that in those implementations, StartWorker can be free to
-// ignore the <index> value; that is, the caller of WorkersPool does not rely on
-// <index> to order tasks with equal <index>.
+// be used.
class WorkersPool {
public:
WorkersPool() {}
@@ -372,18 +372,41 @@ class WorkersPool {
}
}
- void Execute(const std::vector<Task*>& tasks) {
- assert(tasks.size() >= 1);
+ // Just executes the tasks. Does not destroy them. Similar to
+ // ruy::ThreadPool::Execute.
+ template <typename TaskType>
+ void Execute(int tasks_count, TaskType* tasks) {
+ assert(tasks_count >= 1);
// One of the tasks will be run on the current thread.
- std::size_t workers_count = tasks.size() - 1;
+ std::size_t workers_count = tasks_count - 1;
CreateWorkers(workers_count);
assert(workers_count <= workers_.size());
counter_to_decrement_when_ready_.Reset(workers_count);
- int n = 0;
- std::for_each(tasks.begin(), --tasks.end(),
- [this, &n](Task* task) { workers_[n++]->StartWork(task); });
+ for (std::size_t i = 0; i < tasks_count - 1; i++) {
+ workers_[i]->StartWork(&tasks[i]);
+ }
// Execute the remaining workload immediately on the current thread.
- Task* task = tasks.back();
+ Task* task = &tasks[tasks_count - 1];
+ task->local_allocator = &main_thread_task_allocator_;
+ task->Run();
+ // Wait for the workers submitted above to finish.
+ counter_to_decrement_when_ready_.Wait();
+ }
+
+ // Legacy: executes the tasks and destroys them
+ void LegacyExecuteAndDestroyTasks(const std::vector<Task*>& tasks) {
+ std::size_t tasks_count = tasks.size();
+ assert(tasks_count >= 1);
+ // One of the tasks will be run on the current thread.
+ std::size_t workers_count = tasks_count - 1;
+ CreateWorkers(workers_count);
+ assert(workers_count <= workers_.size());
+ counter_to_decrement_when_ready_.Reset(workers_count);
+ for (int i = 0; i < tasks_count - 1; i++) {
+ workers_[i]->StartWork(tasks[i]);
+ }
+ // Execute the remaining workload immediately on the current thread.
+ Task* task = tasks[tasks_count - 1];
task->local_allocator = &main_thread_task_allocator_;
task->Run();
// Wait for the workers submitted above to finish.
@@ -393,6 +416,11 @@ class WorkersPool {
std::for_each(tasks.begin(), tasks.end(), [](Task* task) { delete task; });
}
+ // Legacy old name of LegacyExecuteAndDestroyTasks
+ void Execute(const std::vector<Task*>& tasks) {
+ LegacyExecuteAndDestroyTasks(tasks);
+ }
+
private:
// Ensures that the pool has at least the given count of workers.
// If any new worker has to be created, this function waits for it to
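
Hypothetical usage of the two Execute entry points above, to make the ownership difference concrete (MyTask is illustrative; it only needs to derive from gemmlowp::Task and implement Run()):

    // Assumes "internal/multi_thread_gemm.h" has been included.
    #include <vector>

    struct MyTask : gemmlowp::Task {
      void Run() override { /* one slice of the GEMM */ }
    };

    void RunOnPool(gemmlowp::WorkersPool* pool) {
      // New path: the caller owns the task storage; the pool does not delete it.
      MyTask tasks[4];
      pool->Execute(4, tasks);

      // Legacy path: heap-allocated tasks, deleted by the pool after running.
      std::vector<gemmlowp::Task*> legacy;
      for (int i = 0; i < 4; i++) legacy.push_back(new MyTask);
      pool->Execute(legacy);
    }
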
diff --git a/internal/output.h b/internal/output.h
index dcfe2b5..92bf7b9 100644
--- a/internal/output.h
+++ b/internal/output.h
@@ -22,6 +22,7 @@
#include <cmath>
#include <tuple>
#include <type_traits>
+#include <typeinfo>
#include "../fixedpoint/fixedpoint.h"
#include "../public/output_stages.h"
@@ -179,7 +180,47 @@ struct OutputStageEvalBufferImpl<OutputStageScaleInt32ByFixedPointAndExponent,
int right_shift;
};
-// Implementation of OutputStageSaturatingCastToUint8 for scalar data
+template <int Rows, int Cols, VectorShape Shape>
+struct OutputStageEvalImpl<
+ OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>,
+ RegisterBlock<std::int32_t, Rows, Cols>> {
+ typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
+ typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
+
+ typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> OutputStage;
+
+ OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
+
+ OutputType Eval(InputType input, int row, int col) const {
+ OutputType output;
+ const int pos = Shape == VectorShape::Row ? col : row;
+ using RegisterType = typename InputType::RegisterType;
+ const RegisterType result_offset_after_shift =
+ Dup<RegisterType>(output_stage.result_offset_after_shift);
+ auto left_shift =
+ LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
+ auto right_shift =
+ LoadForBroadcasting<InputType>(output_stage.result_exponent, pos);
+ const auto result_fixedpoint_multiplier = LoadForBroadcasting<InputType>(
+ output_stage.result_fixedpoint_multiplier, pos);
+ for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) {
+ left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0);
+ right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0);
+ }
+ const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul(
+ BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier);
+ const auto rdpot_val =
+ BroadcastRoundingDivideByPOT(mulhigh_val, right_shift);
+ for (int i = 0; i < InputType::kRegisterCount; i++) {
+ output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift);
+ }
+ return output;
+ }
+
+ const OutputStage& output_stage;
+};
+
+// Implementation of OutputStageSaturatingCastToUint8 for scalar data.
template <int Size>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
RegisterBuffer<std::int32_t, Size>> {
@@ -202,7 +243,30 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
}
};
-// Implementation of OutputStageSaturatingCastToInt16 for scalar data
+// Implementation of OutputStageSaturatingCastToInt8 for scalar data.
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+ RegisterBuffer<std::int32_t, Size>> {
+ typedef RegisterBuffer<std::int32_t, Size> InputType;
+ typedef RegisterBuffer<std::int8_t, Size> OutputType;
+ static_assert(InputType::kRegisterLanes == 1,
+ "This path is only for scalar values");
+
+ typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ for (int i = 0; i < InputType::kRegisterCount; i++) {
+ std::int32_t data = input.reg[i];
+ output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data;
+ }
+ return output;
+ }
+};
+
+// Implementation of OutputStageSaturatingCastToInt16 for scalar data.
template <int Size>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
RegisterBuffer<std::int32_t, Size>> {
@@ -225,6 +289,28 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
}
};
+// Implementation of OutputStageTruncatingCastToUint8 for scalar data.
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+ RegisterBuffer<std::int32_t, Size>> {
+ typedef RegisterBuffer<std::int32_t, Size> InputType;
+ typedef RegisterBuffer<std::uint8_t, Size> OutputType;
+ static_assert(InputType::kRegisterLanes == 1,
+ "This path is only for scalar values");
+
+ typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ for (int i = 0; i < InputType::kRegisterCount; i++) {
+ output.reg[i] = input.reg[i];
+ }
+ return output;
+ }
+};
+
template <int Rows, int Cols, typename VectorType>
struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>,
RegisterBlock<std::int32_t, Rows, Cols>> {
@@ -452,7 +538,7 @@ struct OutputPipelineExecutor {
OutputPipelineExecutor(const OutputPipelineType& output_pipeline)
: output_pipeline_eval_impl_(output_pipeline) {}
- // RunOutputPipeline is the entry point into the output pipeline evaluation
+ // Execute is the entry point into the output pipeline evaluation
// code. It should be the only thing that unpack code calls. It takes the
// result
// of the unpack stage and stores it into the destination matrix.
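For readers tracing the new per-channel stage, here is a minimal scalar sketch of what the OutputStageScaleInt32ByFixedPointAndExponentPC specialization above computes for one accumulator. It assumes the usual reference semantics of the fixed-point helpers in fixedpoint/fixedpoint.h; the helper names and bodies below are illustrative, not copies of the library code.

#include <cstdint>
#include <limits>

// Assumed reference semantics of the fixed-point primitives used above.
std::int32_t SatRoundingDoublingHighMul(std::int32_t a, std::int32_t b) {
  const bool overflow =
      a == b && a == std::numeric_limits<std::int32_t>::min();
  const std::int64_t ab = static_cast<std::int64_t>(a) * b;
  const std::int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  const std::int32_t high =
      static_cast<std::int32_t>((ab + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : high;
}

std::int32_t RoundingDivByPOT(std::int32_t x, int exponent) {
  const std::int32_t mask = static_cast<std::int32_t>((1ll << exponent) - 1);
  const std::int32_t remainder = x & mask;
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

// One element of the per-channel stage: left-shift for positive exponents,
// fixed-point multiply, rounding right-shift for negative exponents, then
// add the offset; multiplier and exponent are taken per row or per column.
std::int32_t ScalePerChannel(std::int32_t acc, std::int32_t multiplier,
                             int exponent, std::int32_t offset_after_shift) {
  const int left_shift = exponent > 0 ? exponent : 0;
  const int right_shift = exponent < 0 ? -exponent : 0;
  const std::int32_t mul =
      SatRoundingDoublingHighMul(acc * (1 << left_shift), multiplier);
  return RoundingDivByPOT(mul, right_shift) + offset_after_shift;
}
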
diff --git a/internal/output_avx.h b/internal/output_avx.h
new file mode 100644
index 0000000..b8f94fb
--- /dev/null
+++ b/internal/output_avx.h
@@ -0,0 +1,19 @@
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// output_avx.h: optimized AVX2 specializations of the templates in output.h.
+
+#ifndef GEMMLOWP_INTERNAL_OUTPUT_AVX_H_
+#define GEMMLOWP_INTERNAL_OUTPUT_AVX_H_
+
+#endif // GEMMLOWP_INTERNAL_OUTPUT_AVX_H_
diff --git a/internal/output_msa.h b/internal/output_msa.h
index 4c8eb5d..0540bb3 100644
--- a/internal/output_msa.h
+++ b/internal/output_msa.h
@@ -38,18 +38,14 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
// Signed saturate each 32-bit element to 9 bits
// (this takes full care of non-negative elements).
v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
+ // Zero out negative elements.
+ tmp = __builtin_msa_maxi_s_w(tmp, 0);
// Pack every 32-bit element into 16 bits.
tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
- // Detect negative elements with arithmetic shift right (we
- // get a 16-bit mask of all zeroes or all ones for every element).
- v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15);
- // Zero out negative elements.
- signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
- reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0));
// Pack every element into 8 bits.
tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
- reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
+ reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
// Return 4 uint8_t elements as uint32_t.
output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
return output;
@@ -76,15 +72,12 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
// combining all 8 elements into one vector.
tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
- // Detect negative elements with arithmetic shift right (we
- // get a 16-bit mask of all zeroes or all ones for every element).
- v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15);
// Zero out negative elements.
- signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
- reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0));
+ tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h(
+ reinterpret_cast<v8i16>(tmp_lo), 0));
// Pack every element into 8 bits.
tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
- reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
+ reinterpret_cast<v16i8>(tmp_lo), reinterpret_cast<v16i8>(tmp_lo)));
// Return 8 uint8_t elements as 2 uint32_t's.
output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
@@ -102,15 +95,13 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0))); \
tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( \
reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2))); \
- v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15); \
- v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15); \
- signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \
- reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \
- signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \
- reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \
- signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b( \
- reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0))); \
- out = reinterpret_cast<v16i8>(signs0); \
+ tmp0 = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h( \
+ reinterpret_cast<v8i16>(tmp0), 0)); \
+ tmp2 = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h( \
+ reinterpret_cast<v8i16>(tmp2), 0)); \
+ tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_b( \
+ reinterpret_cast<v16i8>(tmp2), reinterpret_cast<v16i8>(tmp0))); \
+ out = reinterpret_cast<v16i8>(tmp0); \
}
template <>
@@ -166,8 +157,8 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
OutputType Eval(InputType input) const {
OutputType output;
// Signed saturate each 32-bit element to 16 bits.
- v8i16 tmp = reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(
- input.reg[0], 15));
+ v8i16 tmp =
+ reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(input.reg[0], 15));
output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
@@ -176,12 +167,12 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
}
};
-#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1) \
- { \
- v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15); \
- v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15); \
- out = __builtin_msa_pckev_h( \
- reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)); \
+#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1) \
+ { \
+ v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15); \
+ v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15); \
+ out = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(tmp1), \
+ reinterpret_cast<v8i16>(tmp0)); \
}
template <>
@@ -241,6 +232,117 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
#undef GEMMLOWP_MIPS_SAT_I16_8
+template <>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferUint8<4> OutputType;
+
+ typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Pack every 32-bit element into 16 bits.
+ v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(input.reg[0]),
+ reinterpret_cast<v8i16>(input.reg[0])));
+ // Pack every element into 8 bits.
+ tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
+ reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
+ // Return 4 uint8_t elements as uint32_t.
+ output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferUint8<8> OutputType;
+
+ typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Pack every 32-bit element into 16 bits.
+ v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(input.reg[1]),
+ reinterpret_cast<v8i16>(input.reg[0])));
+ // Pack every element into 8 bits.
+ tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
+ reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
+ // Return 8 uint8_t elements as 2 uint32_t's.
+ output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
+ output.reg[1] = __builtin_msa_copy_s_w(tmp, 1);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferUint8<16> OutputType;
+
+ typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Pack every 32-bit element into 16 bits.
+ v8i16 tmp0 = __builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(input.reg[1]),
+ reinterpret_cast<v8i16>(input.reg[0]));
+ v8i16 tmp1 = __builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(input.reg[3]),
+ reinterpret_cast<v8i16>(input.reg[2]));
+ // Pack every element into 8 bits.
+ output.reg[0] = __builtin_msa_pckev_b(
+ reinterpret_cast<v16i8>(tmp1), reinterpret_cast<v16i8>(tmp0));
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferUint8<32> OutputType;
+
+ typedef OutputStageTruncatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Pack every 32-bit element into 16 bits.
+ v8i16 tmp0 = __builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(input.reg[1]),
+ reinterpret_cast<v8i16>(input.reg[0]));
+ v8i16 tmp1 = __builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(input.reg[3]),
+ reinterpret_cast<v8i16>(input.reg[2]));
+ v8i16 tmp2 = __builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(input.reg[5]),
+ reinterpret_cast<v8i16>(input.reg[4]));
+ v8i16 tmp3 = __builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(input.reg[7]),
+ reinterpret_cast<v8i16>(input.reg[6]));
+ // Pack every element into 8 bits.
+ output.reg[0] = __builtin_msa_pckev_b(
+ reinterpret_cast<v16i8>(tmp1), reinterpret_cast<v16i8>(tmp0));
+ output.reg[1] = __builtin_msa_pckev_b(
+ reinterpret_cast<v16i8>(tmp3), reinterpret_cast<v16i8>(tmp2));
+ return output;
+ }
+};
+
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
@@ -474,50 +576,50 @@ struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
}
} else {
// top-left 4x4
- v4i32 t0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[1],
- src.buf.reg[0]));
- v4i32 t1 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[3],
- src.buf.reg[2]));
+ v4i32 t0 = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvr_h(src.buf.reg[1], src.buf.reg[0]));
+ v4i32 t1 = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvr_h(src.buf.reg[3], src.buf.reg[2]));
v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
// top-right 4x4
- v4i32 t2 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[5],
- src.buf.reg[4]));
- v4i32 t3 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[7],
- src.buf.reg[6]));
+ v4i32 t2 = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvr_h(src.buf.reg[5], src.buf.reg[4]));
+ v4i32 t3 = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvr_h(src.buf.reg[7], src.buf.reg[6]));
v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
// bottom-left 4x4
- v4i32 t4 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[1],
- src.buf.reg[0]));
- v4i32 t5 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[3],
- src.buf.reg[2]));
+ v4i32 t4 = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvl_h(src.buf.reg[1], src.buf.reg[0]));
+ v4i32 t5 = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvl_h(src.buf.reg[3], src.buf.reg[2]));
v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
// bottom-right 4x4
- v4i32 t6 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[5],
- src.buf.reg[4]));
- v4i32 t7 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[7],
- src.buf.reg[6]));
+ v4i32 t6 = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvl_h(src.buf.reg[5], src.buf.reg[4]));
+ v4i32 t7 = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvl_h(src.buf.reg[7], src.buf.reg[6]));
v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));
- StoreInt16x8(dst->data(row + 0, col), reinterpret_cast<v8i16>(
- __builtin_msa_ilvr_d(u2, u0)));
- StoreInt16x8(dst->data(row + 1, col), reinterpret_cast<v8i16>(
- __builtin_msa_ilvl_d(u2, u0)));
- StoreInt16x8(dst->data(row + 2, col), reinterpret_cast<v8i16>(
- __builtin_msa_ilvr_d(u3, u1)));
- StoreInt16x8(dst->data(row + 3, col), reinterpret_cast<v8i16>(
- __builtin_msa_ilvl_d(u3, u1)));
- StoreInt16x8(dst->data(row + 4, col), reinterpret_cast<v8i16>(
- __builtin_msa_ilvr_d(u6, u4)));
- StoreInt16x8(dst->data(row + 5, col), reinterpret_cast<v8i16>(
- __builtin_msa_ilvl_d(u6, u4)));
- StoreInt16x8(dst->data(row + 6, col), reinterpret_cast<v8i16>(
- __builtin_msa_ilvr_d(u7, u5)));
- StoreInt16x8(dst->data(row + 7, col), reinterpret_cast<v8i16>(
- __builtin_msa_ilvl_d(u7, u5)));
+ StoreInt16x8(dst->data(row + 0, col),
+ reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u2, u0)));
+ StoreInt16x8(dst->data(row + 1, col),
+ reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u2, u0)));
+ StoreInt16x8(dst->data(row + 2, col),
+ reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u3, u1)));
+ StoreInt16x8(dst->data(row + 3, col),
+ reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u3, u1)));
+ StoreInt16x8(dst->data(row + 4, col),
+ reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u6, u4)));
+ StoreInt16x8(dst->data(row + 5, col),
+ reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u6, u4)));
+ StoreInt16x8(dst->data(row + 6, col),
+ reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u7, u5)));
+ StoreInt16x8(dst->data(row + 7, col),
+ reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u7, u5)));
}
}
};
@@ -585,6 +687,391 @@ struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
}
};
+// There's no way to express in C++ the desired machine code for
+// StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> and
+// StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType>.
+// Hence, if we can, we use inline assembly, which takes advantage
+// of little-endian byte order and specifics of different CPU revisions.
+// Note that clang currently can't derive MSA register names from
+// floating-point register names, or vice versa, in inline assembly.
+#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) && \
+ !defined(__clang__)
+
+// Instructions for pointer-sized operands.
+#ifdef GEMMLOWP_MIPS_64
+#define GEMMLOWP_MIPS_XADDU "daddu"
+#define GEMMLOWP_MIPS_XLSA "dlsa"
+#else
+#define GEMMLOWP_MIPS_XADDU "addu"
+#define GEMMLOWP_MIPS_XLSA "lsa"
+#endif
+
+// Stores 4 8-byte half-vectors with a stride.
+inline void MipsMsaStore4x8(const RegBlockUint8<8, 4>& src,
+ std::uint8_t* dst_ptr, int stride) {
+#if (__mips_isa_rev >= 6)
+ // Assembly temporaries that will be handily referred to by their names.
+ std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3;
+ v16i8 vtmp0, vtmp1;
+ asm volatile(
+ GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+ "ilvl.d %w[vtmp0], %w[src0], %w[src0]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+ "ilvl.d %w[vtmp1], %w[src1], %w[src1]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+ "sdc1 %[src0], 0(%[dst_ptr0])\n"
+ "sdc1 %[vtmp0], 0(%[dst_ptr1])\n"
+ "sdc1 %[src1], 0(%[dst_ptr2])\n"
+ "sdc1 %[vtmp1], 0(%[dst_ptr3])\n"
+ :
+ // Outputs.
+ [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+ [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+ [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1)
+ :
+ // Inputs.
+ [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+ [stride] "r"(stride)
+ :
+ // Clobbers.
+ "memory");
+#else
+ // Assembly temporaries that will be handily referred to by their names.
+ std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3;
+ int tmp0, tmp1, tmp2, tmp3;
+ asm volatile(
+ GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+ "copy_s.w %[tmp0], %w[src0][0]\n"
+ "copy_s.w %[tmp1], %w[src0][1]\n"
+ "copy_s.w %[tmp2], %w[src0][2]\n"
+ "copy_s.w %[tmp3], %w[src0][3]\n"
+ "swr %[tmp0], 0(%[dst_ptr0])\n"
+ "swl %[tmp0], 3(%[dst_ptr0])\n"
+ "swr %[tmp1], 4(%[dst_ptr0])\n"
+ "swl %[tmp1], 7(%[dst_ptr0])\n"
+ "swr %[tmp2], 0(%[dst_ptr1])\n"
+ "swl %[tmp2], 3(%[dst_ptr1])\n"
+ "swr %[tmp3], 4(%[dst_ptr1])\n"
+ "swl %[tmp3], 7(%[dst_ptr1])\n"
+ "copy_s.w %[tmp0], %w[src1][0]\n"
+ "copy_s.w %[tmp1], %w[src1][1]\n"
+ "copy_s.w %[tmp2], %w[src1][2]\n"
+ "copy_s.w %[tmp3], %w[src1][3]\n"
+ "swr %[tmp0], 0(%[dst_ptr2])\n"
+ "swl %[tmp0], 3(%[dst_ptr2])\n"
+ "swr %[tmp1], 4(%[dst_ptr2])\n"
+ "swl %[tmp1], 7(%[dst_ptr2])\n"
+ "swr %[tmp2], 0(%[dst_ptr3])\n"
+ "swl %[tmp2], 3(%[dst_ptr3])\n"
+ "swr %[tmp3], 4(%[dst_ptr3])\n"
+ "swl %[tmp3], 7(%[dst_ptr3])\n"
+ :
+ // Outputs.
+ [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+ [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3), [tmp0] "=&r"(tmp0),
+ [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), [tmp3] "=&r"(tmp3)
+ :
+ // Inputs.
+ [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+ [stride] "r"(stride)
+ :
+ // Clobbers.
+ "memory");
+#endif
+}
+
+// Stores 8 4-byte quarter-vectors with a stride.
+inline void MipsMsaStore8x4(const RegBlockUint8<4, 8>& src,
+ std::uint8_t* dst_ptr, int stride) {
+#if (__mips_isa_rev >= 6)
+ // Assembly temporaries that will be handily referred to by their names.
+ std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
+ *dst_ptr6, *dst_ptr7;
+ int tmp1, tmp2, tmp3;
+ asm volatile(
+ GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
+ "copy_s.w %[tmp1], %w[src0][1]\n"
+ "copy_s.w %[tmp2], %w[src0][2]\n"
+ "copy_s.w %[tmp3], %w[src0][3]\n"
+ "swc1 %[src0], 0(%[dst_ptr0])\n"
+ "sw %[tmp1], 0(%[dst_ptr1])\n"
+ "sw %[tmp2], 0(%[dst_ptr2])\n"
+ "sw %[tmp3], 0(%[dst_ptr3])\n"
+ "copy_s.w %[tmp1], %w[src1][1]\n"
+ "copy_s.w %[tmp2], %w[src1][2]\n"
+ "copy_s.w %[tmp3], %w[src1][3]\n"
+ "swc1 %[src1], 0(%[dst_ptr4])\n"
+ "sw %[tmp1], 0(%[dst_ptr5])\n"
+ "sw %[tmp2], 0(%[dst_ptr6])\n"
+ "sw %[tmp3], 0(%[dst_ptr7])\n"
+ :
+ // Outputs.
+ [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+ [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+ [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
+ [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
+ [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), [tmp3] "=&r"(tmp3)
+ :
+ // Inputs.
+ [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+ [stride] "r"(stride)
+ :
+ // Clobbers.
+ "memory");
+#else
+ // Assembly temporaries that will be handily referred to by their names.
+ std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
+ *dst_ptr6, *dst_ptr7;
+ int tmp0, tmp1, tmp2, tmp3;
+ asm volatile(
+ GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
+ "copy_s.w %[tmp0], %w[src0][0]\n"
+ "copy_s.w %[tmp1], %w[src0][1]\n"
+ "copy_s.w %[tmp2], %w[src0][2]\n"
+ "copy_s.w %[tmp3], %w[src0][3]\n"
+ "swr %[tmp0], 0(%[dst_ptr0])\n"
+ "swl %[tmp0], 3(%[dst_ptr0])\n"
+ "swr %[tmp1], 0(%[dst_ptr1])\n"
+ "swl %[tmp1], 3(%[dst_ptr1])\n"
+ "swr %[tmp2], 0(%[dst_ptr2])\n"
+ "swl %[tmp2], 3(%[dst_ptr2])\n"
+ "swr %[tmp3], 0(%[dst_ptr3])\n"
+ "swl %[tmp3], 3(%[dst_ptr3])\n"
+ "copy_s.w %[tmp0], %w[src1][0]\n"
+ "copy_s.w %[tmp1], %w[src1][1]\n"
+ "copy_s.w %[tmp2], %w[src1][2]\n"
+ "copy_s.w %[tmp3], %w[src1][3]\n"
+ "swr %[tmp0], 0(%[dst_ptr4])\n"
+ "swl %[tmp0], 3(%[dst_ptr4])\n"
+ "swr %[tmp1], 0(%[dst_ptr5])\n"
+ "swl %[tmp1], 3(%[dst_ptr5])\n"
+ "swr %[tmp2], 0(%[dst_ptr6])\n"
+ "swl %[tmp2], 3(%[dst_ptr6])\n"
+ "swr %[tmp3], 0(%[dst_ptr7])\n"
+ "swl %[tmp3], 3(%[dst_ptr7])\n"
+ :
+ // Outputs.
+ [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+ [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+ [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
+ [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
+ [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2),
+ [tmp3] "=&r"(tmp3)
+ :
+ // Inputs.
+ [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+ [stride] "r"(stride)
+ :
+ // Clobbers.
+ "memory");
+#endif
+}
+
+// Stores 8 8-byte half-vectors with a stride.
+inline void MipsMsaStore8x8(const RegBlockUint8<8, 8>& src,
+ std::uint8_t* dst_ptr, int stride) {
+#if (__mips_isa_rev >= 6)
+ // Assembly temporaries that will be handily referred to by their names.
+ std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
+ *dst_ptr6, *dst_ptr7;
+ v16i8 vtmp0, vtmp1, vtmp2, vtmp3;
+ asm volatile(
+ "ilvl.d %w[vtmp0], %w[src0], %w[src0]\n"
+ GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+ "ilvl.d %w[vtmp1], %w[src1], %w[src1]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+ "ilvl.d %w[vtmp2], %w[src2], %w[src2]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
+ "ilvl.d %w[vtmp3], %w[src3], %w[src3]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
+ "sdc1 %[src0], 0(%[dst_ptr0])\n"
+ "sdc1 %[vtmp0], 0(%[dst_ptr1])\n"
+ "sdc1 %[src1], 0(%[dst_ptr2])\n"
+ "sdc1 %[vtmp1], 0(%[dst_ptr3])\n"
+ "sdc1 %[src2], 0(%[dst_ptr4])\n"
+ "sdc1 %[vtmp2], 0(%[dst_ptr5])\n"
+ "sdc1 %[src3], 0(%[dst_ptr6])\n"
+ "sdc1 %[vtmp3], 0(%[dst_ptr7])\n"
+ :
+ // Outputs.
+ [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+ [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+ [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
+ [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
+ [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1), [vtmp2] "=&f"(vtmp2),
+ [vtmp3] "=&f"(vtmp3)
+ :
+ // Inputs.
+ [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+ [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]),
+ [stride] "r"(stride)
+ :
+ // Clobbers.
+ "memory");
+#else
+ // Assembly temporaries that will be handily referred to by their names.
+ std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
+ *dst_ptr6, *dst_ptr7;
+ int tmp0, tmp1, tmp2, tmp3;
+ asm volatile(
+ GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
+ GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
+ "copy_s.w %[tmp0], %w[src0][0]\n"
+ "copy_s.w %[tmp1], %w[src0][1]\n"
+ "copy_s.w %[tmp2], %w[src0][2]\n"
+ "copy_s.w %[tmp3], %w[src0][3]\n"
+ "swr %[tmp0], 0(%[dst_ptr0])\n"
+ "swl %[tmp0], 3(%[dst_ptr0])\n"
+ "swr %[tmp1], 4(%[dst_ptr0])\n"
+ "swl %[tmp1], 7(%[dst_ptr0])\n"
+ "swr %[tmp2], 0(%[dst_ptr1])\n"
+ "swl %[tmp2], 3(%[dst_ptr1])\n"
+ "swr %[tmp3], 4(%[dst_ptr1])\n"
+ "swl %[tmp3], 7(%[dst_ptr1])\n"
+ "copy_s.w %[tmp0], %w[src1][0]\n"
+ "copy_s.w %[tmp1], %w[src1][1]\n"
+ "copy_s.w %[tmp2], %w[src1][2]\n"
+ "copy_s.w %[tmp3], %w[src1][3]\n"
+ "swr %[tmp0], 0(%[dst_ptr2])\n"
+ "swl %[tmp0], 3(%[dst_ptr2])\n"
+ "swr %[tmp1], 4(%[dst_ptr2])\n"
+ "swl %[tmp1], 7(%[dst_ptr2])\n"
+ "swr %[tmp2], 0(%[dst_ptr3])\n"
+ "swl %[tmp2], 3(%[dst_ptr3])\n"
+ "swr %[tmp3], 4(%[dst_ptr3])\n"
+ "swl %[tmp3], 7(%[dst_ptr3])\n"
+ "copy_s.w %[tmp0], %w[src2][0]\n"
+ "copy_s.w %[tmp1], %w[src2][1]\n"
+ "copy_s.w %[tmp2], %w[src2][2]\n"
+ "copy_s.w %[tmp3], %w[src2][3]\n"
+ "swr %[tmp0], 0(%[dst_ptr4])\n"
+ "swl %[tmp0], 3(%[dst_ptr4])\n"
+ "swr %[tmp1], 4(%[dst_ptr4])\n"
+ "swl %[tmp1], 7(%[dst_ptr4])\n"
+ "swr %[tmp2], 0(%[dst_ptr5])\n"
+ "swl %[tmp2], 3(%[dst_ptr5])\n"
+ "swr %[tmp3], 4(%[dst_ptr5])\n"
+ "swl %[tmp3], 7(%[dst_ptr5])\n"
+ "copy_s.w %[tmp0], %w[src3][0]\n"
+ "copy_s.w %[tmp1], %w[src3][1]\n"
+ "copy_s.w %[tmp2], %w[src3][2]\n"
+ "copy_s.w %[tmp3], %w[src3][3]\n"
+ "swr %[tmp0], 0(%[dst_ptr6])\n"
+ "swl %[tmp0], 3(%[dst_ptr6])\n"
+ "swr %[tmp1], 4(%[dst_ptr6])\n"
+ "swl %[tmp1], 7(%[dst_ptr6])\n"
+ "swr %[tmp2], 0(%[dst_ptr7])\n"
+ "swl %[tmp2], 3(%[dst_ptr7])\n"
+ "swr %[tmp3], 4(%[dst_ptr7])\n"
+ "swl %[tmp3], 7(%[dst_ptr7])\n"
+ :
+ // Outputs.
+ [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
+ [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
+ [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
+ [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
+ [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2),
+ [tmp3] "=&r"(tmp3)
+ :
+ // Inputs.
+ [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
+ [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]),
+ [stride] "r"(stride)
+ :
+ // Clobbers.
+ "memory");
+#endif
+}
+
+#undef GEMMLOWP_MIPS_XADDU
+#undef GEMMLOWP_MIPS_XLSA
+
+// Transposes a column-major 8x4 block for storage into a row-major matrix.
+inline RegBlockUint8<4, 8> Transpose(const RegBlockUint8<8, 4>& src) {
+ v16i8 tmp0 = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]);
+ v16i8 tmp1 = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]);
+ RegBlockUint8<4, 8> result;
+ result.buf.reg[0] = __builtin_msa_ilvr_b(tmp1, tmp0);
+ result.buf.reg[1] = __builtin_msa_ilvl_b(tmp1, tmp0);
+ return result;
+}
+
+inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
+ v16i8 tmp0[4];
+ tmp0[0] = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]);
+ tmp0[1] = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]);
+ tmp0[2] = __builtin_msa_ilvr_b(src.buf.reg[3], src.buf.reg[2]);
+ tmp0[3] = __builtin_msa_ilvl_b(src.buf.reg[3], src.buf.reg[2]);
+ v16i8 tmp1[4];
+ tmp1[0] = __builtin_msa_ilvr_b(tmp0[1], tmp0[0]);
+ tmp1[1] = __builtin_msa_ilvl_b(tmp0[1], tmp0[0]);
+ tmp1[2] = __builtin_msa_ilvr_b(tmp0[3], tmp0[2]);
+ tmp1[3] = __builtin_msa_ilvl_b(tmp0[3], tmp0[2]);
+ RegBlockUint8<8, 8> result;
+ result.buf.reg[0] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w(
+ reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0])));
+ result.buf.reg[1] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w(
+ reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0])));
+ result.buf.reg[2] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w(
+ reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1])));
+ result.buf.reg[3] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w(
+ reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1])));
+ return result;
+}
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
+ static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ std::uint8_t* dst_ptr = dst->data(row, col);
+ int col_stride = dst->cols_stride();
+ MipsMsaStore4x8(src, dst_ptr, col_stride);
+ } else {
+ const auto& block = Transpose(src);
+ std::uint8_t* dst_ptr = dst->data(row, col);
+ int row_stride = dst->rows_stride();
+ MipsMsaStore8x4(block, dst_ptr, row_stride);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
+ static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
+ int col) {
+ const auto& block =
+ (DstType::kOrder == MapOrder::ColMajor) ? src : Transpose(src);
+ std::uint8_t* dst_ptr = dst->data(row, col);
+ int stride = dst->stride();
+ MipsMsaStore8x8(block, dst_ptr, stride);
+ }
+};
+
+#else
+
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
@@ -617,6 +1104,8 @@ struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
}
};
+#endif // Endianness, compiler.
+
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
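Both the removed srai/bseli sequence and the new maxi_s-based one implement the same per-element saturating cast; for reference, a scalar statement of that semantics (illustrative only, not gemmlowp API):

#include <algorithm>
#include <cstdint>

// sat_s.w to 9 bits caps the value at 255, maxi_s (new code) or the old
// sign-mask-and-select zeroes out negatives, and the pckev pair keeps the
// low byte of each lane.
std::uint8_t SaturatingCastToUint8(std::int32_t x) {
  return static_cast<std::uint8_t>(std::max(0, std::min(x, 255)));
}
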
diff --git a/internal/output_neon.h b/internal/output_neon.h
index 911fed0..52ea1bc 100644
--- a/internal/output_neon.h
+++ b/internal/output_neon.h
@@ -108,6 +108,90 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
};
template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferInt8<4> OutputType;
+
+ typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ int16x4_t res_16 = vqmovn_s32(input.reg[0]);
+ int8x8_t res_8 = vqmovn_s16(vcombine_s16(res_16, res_16));
+ output.reg[0] = vget_lane_s32(vreinterpret_s32_s8(res_8), 0);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferInt8<8> OutputType;
+
+ typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ int16x8_t res_16 =
+ vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+ output.reg[0] = vqmovn_s16(res_16);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferInt8<16> OutputType;
+
+ typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ int16x8_t res_16_0 =
+ vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+ int16x8_t res_16_1 =
+ vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
+ output.reg[0] = vqmovn_s16(res_16_0);
+ output.reg[1] = vqmovn_s16(res_16_1);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferInt8<32> OutputType;
+
+ typedef OutputStageSaturatingCastToInt8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ int16x8_t res_16[4];
+ for (int i = 0; i < 4; i++) {
+ res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
+ vqmovn_s32(input.reg[2 * i + 1]));
+ }
+ for (int i = 0; i < 4; i++) {
+ output.reg[i] = vqmovn_s16(res_16[i]);
+ }
+ return output;
+ }
+};
+
+template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
RegBufferInt32<4>> {
typedef RegBufferInt32<4> InputType;
@@ -556,8 +640,8 @@ struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
}
} else {
+ int row_stride = dst->rows_stride();
for (int i = 0; i < 4; i++) {
- int row_stride = dst->rows_stride();
std::uint8_t* col_ptr = dst_ptr + i;
vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
@@ -623,6 +707,153 @@ struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<4, 1>, DstType> {
+ static void Run(const RegBlockInt8<4, 1>& src, DstType* dst, int row,
+ int col) {
+ const std::int32_t src_reg = src.buf.reg[0];
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row + i, col) = (src_reg >> (8 * i));
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<1, 4>, DstType> {
+ static void Run(const RegBlockInt8<1, 4>& src, DstType* dst, int row,
+ int col) {
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<8, 1>, DstType> {
+ static void Run(const RegBlockInt8<8, 1>& src, DstType* dst, int row,
+ int col) {
+ std::int8_t* dst_ptr = dst->data(row, col);
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ vst1_s8(dst_ptr, src.buf.reg[0]);
+ } else {
+ const int row_stride = dst->rows_stride();
+ vst1_lane_s8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
+ vst1_lane_s8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
+ vst1_lane_s8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
+ vst1_lane_s8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
+ vst1_lane_s8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
+ vst1_lane_s8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
+ vst1_lane_s8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
+ vst1_lane_s8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<4, 4>, DstType> {
+ static void Run(const RegBlockInt8<4, 4>& src, DstType* dst, int row,
+ int col) {
+ std::int8_t* dst_ptr = dst->data(row, col);
+ const int row_stride = dst->rows_stride();
+ const int col_stride = dst->cols_stride();
+ for (int i = 0; i < 2; i++) {
+ vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
+ src.buf.reg[i], 0);
+ vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
+ src.buf.reg[i], 1);
+ vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
+ src.buf.reg[i], 2);
+ vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
+ src.buf.reg[i], 3);
+ vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
+ src.buf.reg[i], 4);
+ vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
+ src.buf.reg[i], 5);
+ vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
+ src.buf.reg[i], 6);
+ vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
+ src.buf.reg[i], 7);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<8, 4>, DstType> {
+ static void Run(const RegBlockInt8<8, 4>& src, DstType* dst, int row,
+ int col) {
+ std::int8_t* dst_ptr = dst->data(row, col);
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ int col_stride = dst->cols_stride();
+ for (int i = 0; i < 4; i++) {
+ vst1_s8(dst_ptr + i * col_stride, src.buf.reg[i]);
+ }
+ } else {
+ int row_stride = dst->rows_stride();
+ for (int i = 0; i < 4; i++) {
+ std::int8_t* col_ptr = dst_ptr + i;
+ vst1_lane_s8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
+ vst1_lane_s8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
+ vst1_lane_s8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
+ vst1_lane_s8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
+ vst1_lane_s8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
+ vst1_lane_s8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
+ vst1_lane_s8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
+ vst1_lane_s8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
+ }
+ }
+ }
+};
+
+inline RegBlockInt8<8, 8> Transpose(const RegBlockInt8<8, 8>& src) {
+ int8x8x2_t a[4];
+ a[0] = vtrn_s8(src.buf.reg[0], src.buf.reg[1]);
+ a[1] = vtrn_s8(src.buf.reg[2], src.buf.reg[3]);
+ a[2] = vtrn_s8(src.buf.reg[4], src.buf.reg[5]);
+ a[3] = vtrn_s8(src.buf.reg[6], src.buf.reg[7]);
+ int16x4x2_t b[4];
+ b[0] = vtrn_s16(vreinterpret_s16_s8(a[0].val[0]),
+ vreinterpret_s16_s8(a[1].val[0]));
+ b[1] = vtrn_s16(vreinterpret_s16_s8(a[0].val[1]),
+ vreinterpret_s16_s8(a[1].val[1]));
+ b[2] = vtrn_s16(vreinterpret_s16_s8(a[2].val[0]),
+ vreinterpret_s16_s8(a[3].val[0]));
+ b[3] = vtrn_s16(vreinterpret_s16_s8(a[2].val[1]),
+ vreinterpret_s16_s8(a[3].val[1]));
+ int32x2x2_t c[4];
+ c[0] = vtrn_s32(vreinterpret_s32_s16(b[0].val[0]),
+ vreinterpret_s32_s16(b[2].val[0]));
+ c[1] = vtrn_s32(vreinterpret_s32_s16(b[1].val[0]),
+ vreinterpret_s32_s16(b[3].val[0]));
+ c[2] = vtrn_s32(vreinterpret_s32_s16(b[0].val[1]),
+ vreinterpret_s32_s16(b[2].val[1]));
+ c[3] = vtrn_s32(vreinterpret_s32_s16(b[1].val[1]),
+ vreinterpret_s32_s16(b[3].val[1]));
+ RegBlockInt8<8, 8> result;
+ result.buf.reg[0] = vreinterpret_s8_s32(c[0].val[0]);
+ result.buf.reg[1] = vreinterpret_s8_s32(c[1].val[0]);
+ result.buf.reg[2] = vreinterpret_s8_s32(c[2].val[0]);
+ result.buf.reg[3] = vreinterpret_s8_s32(c[3].val[0]);
+ result.buf.reg[4] = vreinterpret_s8_s32(c[0].val[1]);
+ result.buf.reg[5] = vreinterpret_s8_s32(c[1].val[1]);
+ result.buf.reg[6] = vreinterpret_s8_s32(c[2].val[1]);
+ result.buf.reg[7] = vreinterpret_s8_s32(c[3].val[1]);
+ return result;
+}
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt8<8, 8>, DstType> {
+ static void Run(const RegBlockInt8<8, 8>& src, DstType* dst, int row,
+ int col) {
+ const auto& block =
+ DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
+ std::int8_t* dst_ptr = dst->data(row, col);
+ int stride = dst->stride();
+ for (int i = 0; i < 8; i++) {
+ vst1_s8(dst_ptr + i * stride, block.buf.reg[i]);
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
int col) {
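The new int8 cast specializations narrow in two saturating steps (vqmovn_s32, then vqmovn_s16). Because [-128, 127] lies inside [-32768, 32767], clamping to int16 first doesn't change the result, so this equals a direct clamp to [-128, 127]; a small standalone check of that equivalence (not part of the patch):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (std::int64_t x = -1000000; x <= 1000000; x += 7) {
    const std::int32_t v = static_cast<std::int32_t>(x);
    // Clamp to int16 range first, then to int8 range.
    const std::int32_t via_int16 =
        std::max(-128, std::min(std::max(-32768, std::min(v, 32767)), 127));
    // Clamp directly to int8 range.
    const std::int32_t direct = std::max(-128, std::min(v, 127));
    assert(via_int16 == direct);
  }
  return 0;
}
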
diff --git a/internal/pack.h b/internal/pack.h
index cb4b93a..7c43d6e 100644
--- a/internal/pack.h
+++ b/internal/pack.h
@@ -72,6 +72,10 @@ class PackedSideBlock {
pos_ += n * KernelSideFormat::Cell::kSize;
}
+ // TODO(suharshs): The datatype can now be int8 as well. We could also
+ // introduce an int8 current_data impl. This change would propagate to all
+ // pack impls and the Kernel::Run API, which currently assume uint8. For now
+ // we leave this as-is, pending a future refactor.
const std::uint8_t* current_data() const {
return allocator_->GetPointer<std::uint8_t>(data_handle_) + pos_;
}
@@ -208,6 +212,7 @@ class PackingRegisterBlockBase {
public:
typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat;
typedef typename KernelSideFormat::Cell CellFormat;
+ typedef typename KernelSideFormat::InputScalar KernelInputScalar;
typedef typename KernelSideFormat::Scalar KernelScalar;
static const int kCells = KernelSideFormat::kCells;
static const int kCellWidth = CellFormat::kWidth;
@@ -216,7 +221,7 @@ class PackingRegisterBlockBase {
static const int kCellSize = CellFormat::kSize;
static const SideMapOrder kSrcOrder = SrcMapType::kOrder;
static const int kZeroPointInputValue =
- ZeroPointInputValue<KernelScalar>::kValue;
+ ZeroPointInputValue<KernelInputScalar, KernelScalar>::kValue;
PackingRegisterBlockBase() : complete_src_(nullptr, 0, 0, 0) {}
@@ -233,7 +238,7 @@ class PackingRegisterBlockBase {
std::uint8_t buf_[kKernelWidth * kRegisterSize];
public:
- // Selects a block if in-place source data that's already a complete block
+ // Selects a block of in-place source data that's already a complete block.
void UseCompleteSrcInPlace(const SrcMapType& src) { complete_src_ = src; }
// Copies an incomplete block of source data into a local temporary
// complete block by zero-extending it.
@@ -249,7 +254,10 @@ class PackingRegisterBlockBase {
memcpy(buf_ + d * kKernelWidth, src.data(0, d), src.width());
}
}
- complete_src_ = SrcMapType(buf_, kKernelWidth, kRegisterSize);
+
+ // Since the KernelInputScalar type may not be uint8, we need to cast buf_.
+ complete_src_ = SrcMapType(reinterpret_cast<KernelInputScalar*>(buf_),
+ kKernelWidth, kRegisterSize);
}
// Packs a complete block into the destination. This is the most
// critical part and the part that we most typically want to
@@ -340,7 +348,7 @@ class PackSideBlockImpl {
}
}
- // Prefetches the data that will be read by PackL1
+ // Prefetches the data that will be read by PackL1.
void PrefetchL1(int start_width, int width, int start_depth, int depth) {
if (SrcMapType::kOrder == SideMapOrder::WidthMajor) {
for (int d = 0; d < depth; d += kDefaultCacheLineSize) {
@@ -394,7 +402,7 @@ class PackSideBlockImpl {
const SrcMapType& src_map_;
};
-// Packs a block of the input LHS matrix, into a PackedSideBlock
+// Packs a block of the input LHS matrix into a PackedSideBlock.
template <typename PackedSideBlock, typename MatrixMapType>
void PackLhs(PackedSideBlock* dst, const MatrixMapType& src) {
ScopedProfilingLabel label("pack LHS");
@@ -409,7 +417,7 @@ void PackLhs(PackedSideBlock* dst, const MatrixMapType& src) {
impl.PackL2();
}
-// Packs a block of the input RHS matrix, into a PackedSideBlock
+// Packs a block of the input RHS matrix into a PackedSideBlock.
template <typename PackedSideBlock, typename MatrixMapType>
void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) {
ScopedProfilingLabel label("pack RHS");
@@ -430,6 +438,8 @@ void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) {
#include "pack_neon.h"
#elif defined(GEMMLOWP_SSE4)
#include "pack_sse.h"
+#elif defined(GEMMLOWP_AVX2)
+#include "pack_avx.h"
#elif defined(GEMMLOWP_MSA)
#include "pack_msa.h"
#endif
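The two-parameter ZeroPointInputValue above keys the padding byte for incomplete blocks on both the input scalar type and the kernel's scalar type. A hedged sketch of the idea follows; the concrete values are assumptions for illustration, not copied from gemmlowp's definitions.

#include <cstdint>

// Sketch: incomplete blocks must be padded with whatever byte the kernel
// treats as numeric zero. The specializations below are illustrative
// assumptions, not the library's actual trait.
template <typename InputScalar, typename KernelScalar>
struct PadValueSketch;

template <>
struct PadValueSketch<std::uint8_t, std::uint8_t> {
  static constexpr std::uint8_t kValue = 0;  // plain uint8 kernels
};

template <>
struct PadValueSketch<std::uint8_t, std::int8_t> {
  // uint8 inputs feeding an int8 kernel: packing subtracts 128 (e.g. by
  // flipping bit 7), so padding with 128 turns into 0 after the conversion.
  static constexpr std::uint8_t kValue = 128;
};

template <>
struct PadValueSketch<std::int8_t, std::int8_t> {
  static constexpr std::uint8_t kValue = 0;  // int8 inputs used directly
};
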
diff --git a/internal/pack_avx.h b/internal/pack_avx.h
new file mode 100644
index 0000000..1ef5ce1
--- /dev/null
+++ b/internal/pack_avx.h
@@ -0,0 +1,282 @@
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// pack_avx.h: optimized AVX specializations of the templates in pack.h.
+
+#ifndef GEMMLOWP_INTERNAL_PACK_AVX_H_
+#define GEMMLOWP_INTERNAL_PACK_AVX_H_
+
+#include <immintrin.h>
+#include "pack.h"
+
+namespace gemmlowp {
+
+// TODO: Add DepthMajorUint8SideMap
+
+typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
+ WidthMajorUint8SideMap;
+
+template <int Cells>
+using WidthMajorSideFormatNCells4x2 =
+ KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, Cells>;
+
+template <int Cells>
+class PackingRegisterBlock<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
+ : public PackingRegisterBlockBase<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
+ public:
+ typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
+ typedef typename KernelSideFormat::Cell CellFormat;
+ static const int kCells = KernelSideFormat::kCells;
+ static const int kCellWidth = CellFormat::kWidth;
+ static const int kKernelWidth = CellFormat::kWidth * kCells;
+ static const int kCellDepth = CellFormat::kDepth;
+ static const int kCellSize = CellFormat::kSize;
+
+ void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
+ std::uint8_t *dst_ptr = dst->current_data();
+ const int width_stride = this->complete_src_.width_stride();
+ int depth_step = 16;
+
+ __m256i one = _mm256_set1_epi16(1);
+ for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
+ cell_start_depth += depth_step) {
+ for (int cell_start_width = 0; cell_start_width < kKernelWidth;
+ cell_start_width += kCellWidth) {
+ std::int32_t *cell_sums_of_each_slice_ptr =
+ dst->sums_of_each_slice() + start_width + cell_start_width;
+ const std::uint8_t *src_data =
+ this->complete_src_.data(cell_start_width, cell_start_depth);
+
+ __m128i xmm1 =
+ _mm_loadu_si128(reinterpret_cast<const __m128i *>(&src_data[0]));
+ __m128i xmm2 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
+ __m128i xmm3 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
+ __m128i xmm4 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));
+ __m128i xmm5 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i *>(&src_data[4 * width_stride]));
+ __m128i xmm6 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i *>(&src_data[5 * width_stride]));
+ __m128i xmm7 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i *>(&src_data[6 * width_stride]));
+ __m128i xmm8 = _mm_loadu_si128(
+ reinterpret_cast<const __m128i *>(&src_data[7 * width_stride]));
+
+ __m256i ymm1 = _mm256_set_m128i(xmm5, xmm1);
+ __m256i ymm2 = _mm256_set_m128i(xmm6, xmm2);
+ __m256i ymm3 = _mm256_set_m128i(xmm7, xmm3);
+ __m256i ymm4 = _mm256_set_m128i(xmm8, xmm4);
+
+ __m256i ymm5 = _mm256_unpacklo_epi16(ymm1, ymm2);
+ __m256i ymm6 = _mm256_unpacklo_epi16(ymm3, ymm4);
+
+ __m256i ymm9 = _mm256_unpackhi_epi16(ymm1, ymm2);
+ __m256i ymm10 = _mm256_unpackhi_epi16(ymm3, ymm4);
+
+ __m256i ymm7 = _mm256_unpacklo_epi32(ymm5, ymm6);
+ __m256i ymm8 = _mm256_unpackhi_epi32(ymm5, ymm6);
+
+ __m256i ymm13 = _mm256_unpacklo_epi32(ymm9, ymm10);
+ __m256i ymm14 = _mm256_unpackhi_epi32(ymm9, ymm10);
+
+ __m256i ymm11 = _mm256_permute4x64_epi64(ymm7, 0xd8);
+ __m256i ymm12 = _mm256_permute4x64_epi64(ymm8, 0xd8);
+
+ __m256i ymm15 = _mm256_permute4x64_epi64(ymm13, 0xd8);
+ __m256i ymm16 = _mm256_permute4x64_epi64(ymm14, 0xd8);
+
+ __m128i xmm9 = _mm256_castsi256_si128(ymm11);
+ __m128i xmm10 = _mm256_castsi256_si128(ymm12);
+ __m128i xmm11 = _mm256_extracti128_si256(ymm11, 1);
+ __m128i xmm12 = _mm256_extracti128_si256(ymm12, 1);
+
+ xmm1 = _mm256_castsi256_si128(ymm15);
+ xmm2 = _mm256_castsi256_si128(ymm16);
+ xmm3 = _mm256_extracti128_si256(ymm15, 1);
+ xmm4 = _mm256_extracti128_si256(ymm16, 1);
+
+ _mm_storeu_si128(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm11);
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
+ xmm10);
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
+ xmm12);
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i *>(&dst_ptr[4 * kCellSize * kCells]),
+ xmm1);
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i *>(&dst_ptr[5 * kCellSize * kCells]),
+ xmm3);
+
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i *>(&dst_ptr[6 * kCellSize * kCells]),
+ xmm2);
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i *>(&dst_ptr[7 * kCellSize * kCells]),
+ xmm4);
+
+ ymm6 = _mm256_cvtepu8_epi16(xmm9);
+ ymm7 = _mm256_madd_epi16(ymm6, one);
+ __m256i sums_of_each_slice_xmm = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i *>(&cell_sums_of_each_slice_ptr[0]));
+ sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+ ymm6 = _mm256_cvtepu8_epi16(xmm11);
+ ymm7 = _mm256_madd_epi16(ymm6, one);
+ sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+ ymm6 = _mm256_cvtepu8_epi16(xmm10);
+ ymm7 = _mm256_madd_epi16(ymm6, one);
+ sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+ ymm6 = _mm256_cvtepu8_epi16(xmm12);
+ ymm7 = _mm256_madd_epi16(ymm6, one);
+ sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+ ymm6 = _mm256_cvtepu8_epi16(xmm1);
+ ymm7 = _mm256_madd_epi16(ymm6, one);
+ sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+ ymm6 = _mm256_cvtepu8_epi16(xmm3);
+ ymm7 = _mm256_madd_epi16(ymm6, one);
+ sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+ ymm6 = _mm256_cvtepu8_epi16(xmm2);
+ ymm7 = _mm256_madd_epi16(ymm6, one);
+ sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+ ymm6 = _mm256_cvtepu8_epi16(xmm4);
+ ymm7 = _mm256_madd_epi16(ymm6, one);
+ sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7);
+
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i *>(&cell_sums_of_each_slice_ptr[0]),
+ sums_of_each_slice_xmm);
+ dst_ptr += kCellSize;
+ }
+ dst_ptr += 7 * kCellSize * kCells;
+ }
+ dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
+ }
+};
+
+// Packing for the 4x2 RHS cell format.
+template <int Cells>
+using RhsWidthMajorSideFormatNCells4x2 =
+ KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
+
+template <int Cells>
+class PackingRegisterBlock<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>>
+ : public PackingRegisterBlockBase<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> {
+ public:
+ typedef RhsWidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
+ typedef typename KernelSideFormat::Cell CellFormat;
+ static const int kCells = KernelSideFormat::kCells;
+ static const int kCellWidth = CellFormat::kWidth;
+ static const int kKernelWidth = CellFormat::kWidth * kCells;
+ static const int kCellDepth = CellFormat::kDepth;
+ static const int kCellSize = CellFormat::kSize;
+
+ void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) {
+ std::uint8_t *dst_ptr = dst->current_data();
+ const int width_stride = this->complete_src_.width_stride();
+ int depth_step = 8;
+
+ __m128i one = _mm_set1_epi16(1);
+ for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
+ cell_start_depth += depth_step) {
+ for (int cell_start_width = 0; cell_start_width < kKernelWidth;
+ cell_start_width += kCellWidth) {
+ std::int32_t *cell_sums_of_each_slice_ptr =
+ dst->sums_of_each_slice() + start_width + cell_start_width;
+ const std::uint8_t *src_data =
+ this->complete_src_.data(cell_start_width, cell_start_depth);
+
+ __m128i xmm1 =
+ _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&src_data[0]));
+ __m128i xmm2 = _mm_loadl_epi64(
+ reinterpret_cast<const __m128i *>(&src_data[1 * width_stride]));
+ __m128i xmm3 = _mm_loadl_epi64(
+ reinterpret_cast<const __m128i *>(&src_data[2 * width_stride]));
+ __m128i xmm4 = _mm_loadl_epi64(
+ reinterpret_cast<const __m128i *>(&src_data[3 * width_stride]));
+
+ __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
+ __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
+
+ __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
+ __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
+
+ __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
+ __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
+
+ _mm_storel_epi64(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm10);
+
+ __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
+ __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
+
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]),
+ xmm11);
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]),
+ xmm12);
+
+ xmm1 = _mm_cvtepu8_epi16(xmm9);
+ xmm2 = _mm_madd_epi16(xmm1, one);
+ __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
+ reinterpret_cast<const __m128i *>(&cell_sums_of_each_slice_ptr[0]));
+ sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+ xmm1 = _mm_cvtepu8_epi16(xmm10);
+ xmm2 = _mm_madd_epi16(xmm1, one);
+ sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+ xmm1 = _mm_cvtepu8_epi16(xmm11);
+ xmm2 = _mm_madd_epi16(xmm1, one);
+ sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+ xmm1 = _mm_cvtepu8_epi16(xmm12);
+ xmm2 = _mm_madd_epi16(xmm1, one);
+ sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
+
+ _mm_storeu_si128(
+ reinterpret_cast<__m128i *>(&cell_sums_of_each_slice_ptr[0]),
+ sums_of_each_slice_xmm);
+ dst_ptr += kCellSize;
+ }
+ dst_ptr += 3 * kCellSize * kCells;
+ }
+ dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
+ }
+};
+
+} // namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_PACK_AVX_H_
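Shuffle choreography aside, the cvtepu8/madd-with-ones chains in both Pack routines above maintain one running sum per packed width slice, later used for zero-point corrections. A scalar reference of that bookkeeping, assuming the transpose network is read correctly (hypothetical helper, not part of the patch):

#include <cstdint>

// For each of the cell's width slices, add up the depth_step source bytes
// packed in this iteration and accumulate into the per-slice sums.
void AccumulateSliceSums(const std::uint8_t* src_data, int cell_width,
                         int depth_step, int width_stride,
                         std::int32_t* cell_sums_of_each_slice) {
  for (int w = 0; w < cell_width; ++w) {
    for (int d = 0; d < depth_step; ++d) {
      cell_sums_of_each_slice[w] += src_data[w * width_stride + d];
    }
  }
}
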
diff --git a/internal/pack_msa.h b/internal/pack_msa.h
index fba8a0f..4072229 100644
--- a/internal/pack_msa.h
+++ b/internal/pack_msa.h
@@ -348,6 +348,84 @@ class PackingRegisterBlock<
}
};
+template <int Width>
+using Int8FastKernelFormat =
+ KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
+
+template <int Width>
+class PackingRegisterBlock<WidthMajorUint8SideMap,
+ PackedSideBlock<Int8FastKernelFormat<Width>>>
+ : public PackingRegisterBlockBase<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<Int8FastKernelFormat<Width>>> {
+ public:
+ static_assert(Width == 2 || Width == 4, "");
+ typedef Int8FastKernelFormat<Width> KernelSideFormat;
+ typedef typename KernelSideFormat::Cell CellFormat;
+ static const int kCells = KernelSideFormat::kCells;
+ static const int kCellWidth = CellFormat::kWidth;
+ static const int kKernelWidth = CellFormat::kWidth * kCells;
+ static const int kCellDepth = CellFormat::kDepth;
+ static const int kCellSize = CellFormat::kSize;
+
+ void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+ std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
+ std::uint8_t* dst_ptr = dst->current_data();
+ const std::uint8_t* const src_ptr = this->complete_src_.data();
+ const int stride = this->complete_src_.stride();
+ // Load source WidthMajor data.
+ v16i8 src_lines[Width];
+ for (int i = 0; i < Width; i++) {
+ src_lines[i] = __builtin_msa_ld_b(
+ const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
+ }
+ for (int i = 0; i < Width; i++) {
+ // Subtract 128 by inverting bit 7.
+ src_lines[i] = reinterpret_cast<v16i8>(
+ __builtin_msa_bnegi_b(reinterpret_cast<v16u8>(src_lines[i]), 7));
+ }
+ for (int i = 0; i < Width; i++) {
+ __builtin_msa_st_b(src_lines[i], dst_ptr + 16 * i, 0);
+ }
+ v8i16 sums2[Width];
+ for (int i = 0; i < Width; i++) {
+ sums2[i] = __builtin_msa_hadd_s_h(src_lines[i], src_lines[i]);
+ }
+ v4i32 sums4_wide[Width];
+ for (int i = 0; i < Width; i++) {
+ sums4_wide[i] = __builtin_msa_hadd_s_w(sums2[i], sums2[i]);
+ }
+ v8i16 sums4[Width / 2];
+ for (int i = 0; i < Width / 2; i++) {
+ sums4[i] = __builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(sums4_wide[2 * i + 1]),
+ reinterpret_cast<v8i16>(sums4_wide[2 * i]));
+ }
+ v4i32 sums8_wide[Width / 2];
+ for (int i = 0; i < Width / 2; i++) {
+ sums8_wide[i] = __builtin_msa_hadd_s_w(sums4[i], sums4[i]);
+ }
+ if (Width == 4) {
+ v4i32 sum = __builtin_msa_ld_w(const_cast<std::int32_t*>(sums_ptr), 0);
+ v8i16 sums8 = __builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(sums8_wide[1]),
+ reinterpret_cast<v8i16>(sums8_wide[0]));
+ v4i32 sums16 = __builtin_msa_hadd_s_w(sums8, sums8);
+ sum = __builtin_msa_addv_w(sum, sums16);
+ __builtin_msa_st_w(sum, sums_ptr, 0);
+ } else {
+ assert(Width == 2);
+ std::int32_t sum[2] = { sums_ptr[0], sums_ptr[1] };
+ v2i64 sums16 = __builtin_msa_hadd_s_d(sums8_wide[0], sums8_wide[0]);
+ sum[0] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 0);
+ sum[1] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 2);
+ sums_ptr[0] = sum[0];
+ sums_ptr[1] = sum[1];
+ }
+ dst->seek_forward_n_cells(1);
+ }
+};
+
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_PACK_MSA_H_
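The "subtract 128 by inverting bit 7" step above relies on a standard two's-complement identity; a standalone scalar check (not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // For every uint8 x, (x ^ 0x80) reinterpreted as int8 equals x - 128,
  // assuming two's-complement int8 (true on the targets involved here).
  for (int x = 0; x < 256; ++x) {
    const std::uint8_t flipped = static_cast<std::uint8_t>(x ^ 0x80);
    assert(static_cast<std::int8_t>(flipped) == x - 128);
  }
  return 0;
}
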
diff --git a/internal/pack_neon.h b/internal/pack_neon.h
index 2b08464..f113d9e 100644
--- a/internal/pack_neon.h
+++ b/internal/pack_neon.h
@@ -26,6 +26,9 @@ namespace gemmlowp {
typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
WidthMajorUint8SideMap;
+typedef SideMap<const std::int8_t, SideMapOrder::WidthMajor>
+ WidthMajorInt8SideMap;
+
template <int Cells>
using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;
@@ -315,6 +318,67 @@ class PackingRegisterBlock<WidthMajorUint8SideMap,
}
};
+template <int Width>
+using Int8InputsFastKernelFormat =
+ KernelSideFormatInt8Inputs<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;
+
+// Same as above, but for int8 inputs, avoiding the uint8 -> int8 conversion.
+template <int Width>
+class PackingRegisterBlock<WidthMajorInt8SideMap,
+ PackedSideBlock<Int8InputsFastKernelFormat<Width>>>
+ : public PackingRegisterBlockBase<
+ WidthMajorInt8SideMap,
+ PackedSideBlock<Int8InputsFastKernelFormat<Width>>> {
+ public:
+ static_assert(Width == 2 || Width == 4, "");
+ typedef Int8InputsFastKernelFormat<Width> KernelSideFormat;
+ typedef typename KernelSideFormat::Cell CellFormat;
+ static const int kCells = KernelSideFormat::kCells;
+ static const int kCellWidth = CellFormat::kWidth;
+ static const int kKernelWidth = CellFormat::kWidth * kCells;
+ static const int kCellDepth = CellFormat::kDepth;
+ static const int kCellSize = CellFormat::kSize;
+
+ void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+ std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
+ std::int8_t* dst_ptr = reinterpret_cast<std::int8_t*>(dst->current_data());
+ const std::int8_t* const src_ptr = this->complete_src_.data();
+ const int stride = this->complete_src_.stride();
+ // Load source WidthMajor data
+ int8x16_t src_lines[Width];
+ for (int i = 0; i < Width; i++) {
+ src_lines[i] = vld1q_s8(src_ptr + i * stride);
+ }
+ for (int i = 0; i < Width; i++) {
+ vst1q_s8(dst_ptr + 16 * i, src_lines[i]);
+ }
+ int16x8_t sums2[Width];
+ for (int i = 0; i < Width; i++) {
+ const int8x8_t lo = vget_low_s8(src_lines[i]);
+ const int8x8_t hi = vget_high_s8(src_lines[i]);
+ sums2[i] = vaddl_s8(lo, hi);
+ }
+ int16x8_t sums4[Width / 2];
+ for (int i = 0; i < Width / 2; i++) {
+ sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]);
+ }
+ if (Width == 4) {
+ int32x4_t sum = vld1q_s32(sums_ptr);
+ int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]);
+ sum = vpadalq_s16(sum, sums8);
+ vst1q_s32(sums_ptr, sum);
+ } else {
+ assert(Width == 2);
+ int32x2_t sum = vld1_s32(sums_ptr);
+ int16x4_t sums8 =
+ vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0]));
+ sum = vpadal_s16(sum, sums8);
+ vst1_s32(sums_ptr, sum);
+ }
+ dst->seek_forward_n_cells(1);
+ }
+};
+
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_PACK_NEON_H_
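
The NEON variant reaches the same sums without any uint8-to-int8 conversion: the vaddl_s8 / vpaddq_s16 / vpadalq_s16 chain is just a widening reduction tree, so lane w of the accumulated int32 vector receives the sum of the 16 int8 values of source row w. A scalar statement of that invariant (assuming Width rows of 16 bytes, as in the packing block above):

    #include <cstdint>

    // Reference for what the NEON reduction tree adds into sums_of_each_slice:
    // sums_ptr[w] += sum of the 16 int8 values in row w of the source block.
    inline void AccumulateRowSums(const std::int8_t* src, int stride, int width,
                                  std::int32_t* sums_ptr) {
      for (int w = 0; w < width; w++) {
        std::int32_t row_sum = 0;
        for (int d = 0; d < 16; d++) {
          row_sum += src[w * stride + d];
        }
        sums_ptr[w] += row_sum;
      }
    }
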
diff --git a/internal/platform.h b/internal/platform.h
index 1114767..ab71414 100644
--- a/internal/platform.h
+++ b/internal/platform.h
@@ -18,6 +18,7 @@
#define GEMMLOWP_INTERNAL_PLATFORM_H_
#ifdef _WIN32
+#include <malloc.h>
#include <windows.h>
#else
#include <stdlib.h>
@@ -71,8 +72,8 @@ inline int GetHardwareConcurrency(int max_threads) {
inline double real_time_in_seconds() {
__int64 wintime;
GetSystemTimeAsFileTime((FILETIME *)&wintime);
- wintime -= 116444736000000000i64; // 1jan1601 to 1jan1970
- return wintime / 10000000i64 + wintime % 10000000i64 * 100 * 1e-9;
+ wintime -= 116444736000000000LL; // 1jan1601 to 1jan1970
+ return wintime / 10000000LL + wintime % 10000000LL * 100 * 1e-9;
}
#else
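
For context on the literal change above: 116444736000000000 is the number of 100-nanosecond FILETIME ticks between 1601-01-01 (the FILETIME epoch) and 1970-01-01 (the Unix epoch), and LL is the standard suffix for a long long literal, whereas i64 is MSVC-specific. A small standalone sketch of the same conversion (hypothetical helper, not part of the patch):

    #include <cstdint>

    // Convert a FILETIME tick count (100 ns units since 1601-01-01) to
    // seconds since the Unix epoch, mirroring real_time_in_seconds().
    inline double FiletimeTicksToUnixSeconds(std::int64_t ticks_since_1601) {
      const std::int64_t kEpochDeltaTicks = 116444736000000000LL;
      const std::int64_t t = ticks_since_1601 - kEpochDeltaTicks;
      return t / 10000000LL                      // whole seconds
             + (t % 10000000LL) * 100 * 1e-9;    // leftover ticks, in seconds
    }
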
diff --git a/internal/simd_wrappers.h b/internal/simd_wrappers.h
index d9721c9..4e4cce8 100644
--- a/internal/simd_wrappers.h
+++ b/internal/simd_wrappers.h
@@ -105,10 +105,12 @@ struct FlipLhsRhs {
using FlippedRhsType = RhsType;
static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
const RhsType& rhs) {
+ (void)rhs;
return lhs;
}
static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
const RhsType& rhs) {
+ (void)lhs;
return rhs;
}
};
@@ -119,10 +121,12 @@ struct FlipLhsRhs<LhsType, RhsType, true> {
using FlippedRhsType = LhsType;
static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
const RhsType& rhs) {
+ (void)lhs;
return rhs;
}
static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
const RhsType& rhs) {
+ (void)rhs;
return lhs;
}
};
@@ -192,6 +196,153 @@ typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
}
template <typename Lhs, typename Rhs>
+struct BroadcastShiftLeftImpl {
+ using ResultBlockType =
+ typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+ static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+ ResultBlockType result;
+ static constexpr int Rows = ResultBlockType::kRows;
+ static constexpr int Cols = ResultBlockType::kCols;
+ static constexpr int LhsRows = Lhs::kRows;
+ static constexpr int LhsCols = Lhs::kCols;
+ static constexpr int RhsRows = Rhs::kRows;
+ static constexpr int RhsCols = Rhs::kCols;
+
+ static_assert(LhsRows == Rows || LhsRows == 1, "");
+ static_assert(RhsRows == Rows || RhsRows == 1, "");
+ static_assert(LhsCols == Cols || LhsCols == 1, "");
+ static_assert(RhsCols == Cols || RhsCols == 1, "");
+ static_assert(ResultBlockType::kRegisterLanes == 1,
+ "This path is only for scalar values");
+ static_assert(Lhs::kRegisterLanes == 1,
+ "This path is only for scalar values");
+ static_assert(Rhs::kRegisterLanes == 1,
+ "This path is only for scalar values");
+
+ for (int c = 0; c < Cols; c++) {
+ const int lhs_c = LhsCols == Cols ? c : 0;
+ const int rhs_c = RhsCols == Cols ? c : 0;
+ for (int r = 0; r < Rows; r++) {
+ const int lhs_r = LhsRows == Rows ? r : 0;
+ const int rhs_r = RhsRows == Rows ? r : 0;
+ result.buf.reg[r + c * Rows] =
+ ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
+ rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
+ }
+ }
+ return result;
+ }
+};
+
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft(
+ const Lhs& lhs, const Rhs& rhs) {
+ using Flip = FlipLhsRhs<Lhs, Rhs>;
+ return BroadcastShiftLeftImpl<
+ typename Flip::FlippedLhsType,
+ typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
+ Flip::FlippedRhs(lhs, rhs));
+}
+
+template <typename Lhs, typename Rhs>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl {
+ using ResultBlockType =
+ typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+ static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+ ResultBlockType result;
+ static constexpr int Rows = ResultBlockType::kRows;
+ static constexpr int Cols = ResultBlockType::kCols;
+ static constexpr int LhsRows = Lhs::kRows;
+ static constexpr int LhsCols = Lhs::kCols;
+ static constexpr int RhsRows = Rhs::kRows;
+ static constexpr int RhsCols = Rhs::kCols;
+
+ static_assert(LhsRows == Rows || LhsRows == 1, "");
+ static_assert(RhsRows == Rows || RhsRows == 1, "");
+ static_assert(LhsCols == Cols || LhsCols == 1, "");
+ static_assert(RhsCols == Cols || RhsCols == 1, "");
+ static_assert(ResultBlockType::kRegisterLanes == 1,
+ "This path is only for scalar values");
+ static_assert(Lhs::kRegisterLanes == 1,
+ "This path is only for scalar values");
+ static_assert(Rhs::kRegisterLanes == 1,
+ "This path is only for scalar values");
+
+ for (int c = 0; c < Cols; c++) {
+ const int lhs_c = LhsCols == Cols ? c : 0;
+ const int rhs_c = RhsCols == Cols ? c : 0;
+ for (int r = 0; r < Rows; r++) {
+ const int lhs_r = LhsRows == Rows ? r : 0;
+ const int rhs_r = RhsRows == Rows ? r : 0;
+ result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[lhs_r + lhs_c * LhsRows],
+ rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
+ }
+ }
+ return result;
+ }
+};
+
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
+BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) {
+ using Flip = FlipLhsRhs<Lhs, Rhs>;
+ return BroadcastSaturatingRoundingDoublingHighMulImpl<
+ typename Flip::FlippedLhsType,
+ typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
+ Flip::FlippedRhs(lhs, rhs));
+}
+
+template <typename Lhs, typename Rhs>
+struct BroadcastRoundingDivideByPOTImpl {
+ using ResultBlockType =
+ typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
+ static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
+ ResultBlockType result;
+ static constexpr int Rows = ResultBlockType::kRows;
+ static constexpr int Cols = ResultBlockType::kCols;
+ static constexpr int LhsRows = Lhs::kRows;
+ static constexpr int LhsCols = Lhs::kCols;
+ static constexpr int RhsRows = Rhs::kRows;
+ static constexpr int RhsCols = Rhs::kCols;
+
+ static_assert(LhsRows == Rows || LhsRows == 1, "");
+ static_assert(RhsRows == Rows || RhsRows == 1, "");
+ static_assert(LhsCols == Cols || LhsCols == 1, "");
+ static_assert(RhsCols == Cols || RhsCols == 1, "");
+ static_assert(ResultBlockType::kRegisterLanes == 1,
+ "This path is only for scalar values");
+ static_assert(Lhs::kRegisterLanes == 1,
+ "This path is only for scalar values");
+ static_assert(Rhs::kRegisterLanes == 1,
+ "This path is only for scalar values");
+
+ for (int c = 0; c < Cols; c++) {
+ const int lhs_c = LhsCols == Cols ? c : 0;
+ const int rhs_c = RhsCols == Cols ? c : 0;
+ for (int r = 0; r < Rows; r++) {
+ const int lhs_r = LhsRows == Rows ? r : 0;
+ const int rhs_r = RhsRows == Rows ? r : 0;
+ result.buf.reg[r + c * Rows] =
+ RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
+ rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
+ }
+ }
+ return result;
+ }
+};
+
+template <typename Lhs, typename Rhs>
+typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
+BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) {
+ using Flip = FlipLhsRhs<Lhs, Rhs>;
+ return BroadcastRoundingDivideByPOTImpl<
+ typename Flip::FlippedLhsType,
+ typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
+ Flip::FlippedRhs(lhs, rhs));
+}
+
+template <typename Lhs, typename Rhs>
struct BroadcastMulImpl {
using ResultBlockType =
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
@@ -494,12 +645,16 @@ template <int N>
using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
template <int N>
using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
+template <int N>
+using RegBufferInt8 = RegisterBuffer<std::int8_t, N>;
template <int R, int C>
using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
template <int R, int C>
using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
template <int R, int C>
using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
+template <int R, int C>
+using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>;
} // end namespace gemmlowp
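
All of the new generic Broadcast*Impl structs above follow the same broadcasting rule: in each dimension, Lhs and Rhs either match the result block or have extent 1, and an extent-1 dimension simply reuses index 0. For example, combining a 4x1 block with a 1x1 block yields a 4x1 result in which every row uses the single Rhs value. A scalar sketch of that index rule outside the RegisterBlock machinery (names here are illustrative, not gemmlowp's):

    #include <cstdint>

    // out is rows x cols, column-major; lhs/rhs each have extent rows-or-1 by
    // cols-or-1, and an extent-1 dimension is broadcast by indexing it at 0.
    inline void BroadcastShiftLeftScalar(const std::int32_t* lhs, int lhs_rows,
                                         int lhs_cols, const std::int32_t* rhs,
                                         int rhs_rows, int rhs_cols,
                                         std::int32_t* out, int rows, int cols) {
      for (int c = 0; c < cols; c++) {
        const int lc = lhs_cols == cols ? c : 0;
        const int rc = rhs_cols == cols ? c : 0;
        for (int r = 0; r < rows; r++) {
          const int lr = lhs_rows == rows ? r : 0;
          const int rr = rhs_rows == rows ? r : 0;
          out[r + c * rows] =
              lhs[lr + lc * lhs_rows] << rhs[rr + rc * rhs_rows];
        }
      }
    }
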
diff --git a/internal/simd_wrappers_common_neon_sse.h b/internal/simd_wrappers_common_neon_sse.h
index 3830eb1..694bf99 100644
--- a/internal/simd_wrappers_common_neon_sse.h
+++ b/internal/simd_wrappers_common_neon_sse.h
@@ -350,6 +350,210 @@ struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
}
};
+// 4x1 := 4x1 * 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<4, 1> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 1x4 := 1x4 * 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<1, 4> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 4x1 := 4x1 * 4x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>,
+ RegBlockInt32<4, 1>> {
+ static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+ const RegBlockInt32<4, 1>& rhs) {
+ RegBlockInt32<4, 1> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 1x4 := 1x4 * 1x4
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>,
+ RegBlockInt32<1, 4>> {
+ static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<1, 4> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 4x4 := 4x4 * 1x4
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
+ RegBlockInt32<1, 4>> {
+ static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<4, 4> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 4x4 := 4x4 * 4x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>,
+ RegBlockInt32<4, 1>> {
+ static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+ const RegBlockInt32<4, 1>& rhs) {
+ RegBlockInt32<4, 4> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]);
+ result.buf.reg[2] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
+ result.buf.reg[3] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 8x1 := 8x1 * 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<8, 1> result;
+ const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
+ for (int i = 0; i < 2; i++) {
+ result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p);
+ }
+ return result;
+ }
+};
+
+// 8x1 := 8x1 * 8x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>,
+ RegBlockInt32<8, 1>> {
+ static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+ const RegBlockInt32<8, 1>& rhs) {
+ RegBlockInt32<8, 1> result;
+ for (int i = 0; i < 2; i++) {
+ result.buf.reg[i] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]);
+ }
+ return result;
+ }
+};
+
+// 8x4 := 8x4 * 1x4
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
+ RegBlockInt32<1, 4>> {
+ static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<8, 4> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[2] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[3] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[4] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[5] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[6] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
+ result.buf.reg[7] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 8x4 := 8x4 * 8x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>,
+ RegBlockInt32<8, 1>> {
+ static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+ const RegBlockInt32<8, 1>& rhs) {
+ RegBlockInt32<8, 4> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
+ result.buf.reg[2] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]);
+ result.buf.reg[3] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]);
+ result.buf.reg[4] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]);
+ result.buf.reg[5] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]);
+ result.buf.reg[6] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]);
+ result.buf.reg[7] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]);
+ return result;
+ }
+};
+
+// 1x8 := 1x8 * 1x8
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
+ RegBlockInt32<1, 8>> {
+ static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+ const RegBlockInt32<1, 8>& rhs) {
+ RegBlockInt32<1, 8> result;
+ result.buf.reg[0] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] =
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]);
+ return result;
+ }
+};
+
+// 1x8 := 1x8 * 1x1
+template <>
+struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<1, 8> result;
+ result.buf.reg[0] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ result.buf.reg[1] = SaturatingRoundingDoublingHighMul(
+ lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
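
For readers unfamiliar with the operation being specialized here: SaturatingRoundingDoublingHighMul returns the high 32 bits of 2*a*b with rounding, saturating the single overflow case a == b == INT32_MIN. A scalar sketch along the lines of the generic definition in fixedpoint/fixedpoint.h:

    #include <cstdint>
    #include <limits>

    inline std::int32_t SaturatingRoundingDoublingHighMulScalar(std::int32_t a,
                                                                std::int32_t b) {
      // Only INT32_MIN * INT32_MIN overflows the doubled high-mul range.
      const bool overflow =
          a == b && a == std::numeric_limits<std::int32_t>::min();
      const std::int64_t ab = static_cast<std::int64_t>(a) * b;
      // Round to nearest, with ties away from zero, before taking the high word.
      const std::int64_t nudge = ab >= 0 ? (1LL << 30) : (1 - (1LL << 30));
      const std::int32_t high32 =
          static_cast<std::int32_t>((ab + nudge) / (1LL << 31));
      return overflow ? std::numeric_limits<std::int32_t>::max() : high32;
    }
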
// 4x1 := 4x1 * 1x1
template <>
struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
diff --git a/internal/simd_wrappers_msa.h b/internal/simd_wrappers_msa.h
index cf5e8e9..7de01ff 100644
--- a/internal/simd_wrappers_msa.h
+++ b/internal/simd_wrappers_msa.h
@@ -33,8 +33,7 @@ struct RegisterType<std::int32_t, ScalarCount> {
template <int ScalarCount>
struct RegisterType<std::int16_t, ScalarCount> {
- using Type =
- typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
+ using Type = typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
};
template <int ScalarCount>
@@ -69,13 +68,9 @@ inline Int16x8 LoadInt16x8(const Int16x8* src) {
return __builtin_msa_ld_h(const_cast<Int16x8*>(src), 0);
}
-inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
- __builtin_msa_st_h(value, dst, 0);
-}
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); }
-inline void StoreInt16x8(Int16x8* dst, Int16x8 value) {
- __builtin_msa_st_h(value, dst, 0);
-}
+inline void StoreInt16x8(Int16x8* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); }
inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
return __builtin_msa_ld_b(const_cast<std::uint8_t*>(src), 0);
diff --git a/internal/simd_wrappers_neon.h b/internal/simd_wrappers_neon.h
index 2949173..6871055 100644
--- a/internal/simd_wrappers_neon.h
+++ b/internal/simd_wrappers_neon.h
@@ -25,6 +25,7 @@ using Int32x4 = int32x4_t;
using Int16x4 = int16x4_t;
using Int16x8 = int16x8_t;
using Uint8x8 = uint8x8_t;
+using Int8x8 = int8x8_t;
template <int ScalarCount>
struct RegisterType<std::int32_t, ScalarCount> {
@@ -48,6 +49,14 @@ struct RegisterType<std::uint8_t, ScalarCount> {
std::uint8_t>::type>::type;
};
+template <int ScalarCount>
+struct RegisterType<std::int8_t, ScalarCount> {
+ using Type = typename std::conditional<
+ ScalarCount >= 8, Int8x8,
+ typename std::conditional<ScalarCount >= 4, std::int32_t,
+ std::int8_t>::type>::type;
+};
+
inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }
inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }
@@ -92,6 +101,10 @@ inline Int32x4 Min(Int32x4 a, Int32x4 b) { return vminq_s32(a, b); }
inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); }
+inline Int32x4 Max(Int32x4 a, std::int32_t b) {
+ return vmaxq_s32(a, vdupq_n_s32(b));
+}
+
inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
return vqrdmulhq_n_s32(a, b);
}
@@ -164,6 +177,17 @@ struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
};
template <>
+struct LoadContiguousImpl<RegBlockInt8<8, 8>> {
+ static RegBlockInt8<8, 8> Run(const std::int8_t* src) {
+ RegBlockInt8<8, 8> result;
+ for (int i = 0; i < 8; i++) {
+ result.buf.reg[i] = vld1_s8(src + 8 * i);
+ }
+ return result;
+ }
+};
+
+template <>
struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
RegBlockInt32<8, 8> result;
@@ -174,6 +198,352 @@ struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
}
};
+// 4x1 := 4x1 << 1x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
+ static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<4, 1> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 1x4 := 1x4 << 1x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
+ static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<1, 4> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 4x1 := 4x1 << 4x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
+ static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+ const RegBlockInt32<4, 1>& rhs) {
+ RegBlockInt32<4, 1> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 1x4 := 1x4 << 1x4
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
+ static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<1, 4> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 4x4 := 4x4 << 1x4
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
+ static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<4, 4> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 4x4 := 4x4 << 4x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
+ static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+ const RegBlockInt32<4, 1>& rhs) {
+ RegBlockInt32<4, 4> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[0]);
+ result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
+ result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 8x1 := 8x1 << 1x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
+ static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<8, 1> result;
+ const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
+ for (int i = 0; i < 2; i++) {
+ result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], p);
+ }
+ return result;
+ }
+};
+
+// 8x1 := 8x1 << 8x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
+ static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+ const RegBlockInt32<8, 1>& rhs) {
+ RegBlockInt32<8, 1> result;
+ for (int i = 0; i < 2; i++) {
+ result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], rhs.buf.reg[i]);
+ }
+ return result;
+ }
+};
+
+// 8x4 := 8x4 << 1x4
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
+ static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<8, 4> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
+ result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 8x4 := 8x4 << 8x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
+ static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+ const RegBlockInt32<8, 1>& rhs) {
+ RegBlockInt32<8, 4> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
+ result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
+ result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[1]);
+ result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], rhs.buf.reg[0]);
+ result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], rhs.buf.reg[1]);
+ result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], rhs.buf.reg[0]);
+ result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], rhs.buf.reg[1]);
+ return result;
+ }
+};
+
+// 1x8 := 1x8 << 1x8
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
+ static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+ const RegBlockInt32<1, 8>& rhs) {
+ RegBlockInt32<1, 8> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
+ return result;
+ }
+};
+
+// 1x8 := 1x8 << 1x1
+template <>
+struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
+ static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<1, 8> result;
+ result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 4x1 := 4x1 >> 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<4, 1> result;
+ result.buf.reg[0] =
+ RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 1x4 := 1x4 >> 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<1, 4> result;
+ result.buf.reg[0] =
+ RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 4x1 := 4x1 >> 4x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
+ RegBlockInt32<4, 1>> {
+ static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
+ const RegBlockInt32<4, 1>& rhs) {
+ RegBlockInt32<4, 1> result;
+ result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 1x4 := 1x4 >> 1x4
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
+ RegBlockInt32<1, 4>> {
+ static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<1, 4> result;
+ result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 4x4 := 4x4 >> 1x4
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
+ RegBlockInt32<1, 4>> {
+ static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<4, 4> result;
+ result.buf.reg[0] =
+ RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[1] =
+ RoundingDivideByPOT(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[2] =
+ RoundingDivideByPOT(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[3] =
+ RoundingDivideByPOT(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 4x4 := 4x4 >> 4x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
+ RegBlockInt32<4, 1>> {
+ static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
+ const RegBlockInt32<4, 1>& rhs) {
+ RegBlockInt32<4, 4> result;
+ result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[0]);
+ result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
+ result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[0]);
+ return result;
+ }
+};
+
+// 8x1 := 8x1 >> 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<8, 1> result;
+ const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
+ for (int i = 0; i < 2; i++) {
+ result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], p);
+ }
+ return result;
+ }
+};
+
+// 8x1 := 8x1 >> 8x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
+ RegBlockInt32<8, 1>> {
+ static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
+ const RegBlockInt32<8, 1>& rhs) {
+ RegBlockInt32<8, 1> result;
+ for (int i = 0; i < 2; i++) {
+ result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], rhs.buf.reg[i]);
+ }
+ return result;
+ }
+};
+
+// 8x4 := 8x4 >> 1x4
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
+ RegBlockInt32<1, 4>> {
+ static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+ const RegBlockInt32<1, 4>& rhs) {
+ RegBlockInt32<8, 4> result;
+ result.buf.reg[0] =
+ RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[1] =
+ RoundingDivideByPOT(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
+ result.buf.reg[2] =
+ RoundingDivideByPOT(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[3] =
+ RoundingDivideByPOT(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
+ result.buf.reg[4] =
+ RoundingDivideByPOT(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[5] =
+ RoundingDivideByPOT(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
+ result.buf.reg[6] =
+ RoundingDivideByPOT(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
+ result.buf.reg[7] =
+ RoundingDivideByPOT(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
+// 8x4 := 8x4 >> 8x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
+ RegBlockInt32<8, 1>> {
+ static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
+ const RegBlockInt32<8, 1>& rhs) {
+ RegBlockInt32<8, 4> result;
+ result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
+ result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
+ result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[1]);
+ result.buf.reg[4] = RoundingDivideByPOT(lhs.buf.reg[4], rhs.buf.reg[0]);
+ result.buf.reg[5] = RoundingDivideByPOT(lhs.buf.reg[5], rhs.buf.reg[1]);
+ result.buf.reg[6] = RoundingDivideByPOT(lhs.buf.reg[6], rhs.buf.reg[0]);
+ result.buf.reg[7] = RoundingDivideByPOT(lhs.buf.reg[7], rhs.buf.reg[1]);
+ return result;
+ }
+};
+
+// 1x8 := 1x8 >> 1x8
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
+ RegBlockInt32<1, 8>> {
+ static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+ const RegBlockInt32<1, 8>& rhs) {
+ RegBlockInt32<1, 8> result;
+ result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
+ result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
+ return result;
+ }
+};
+
+// 1x8 := 1x8 >> 1x1
+template <>
+struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
+ RegBlockInt32<1, 1>> {
+ static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
+ const RegBlockInt32<1, 1>& rhs) {
+ RegBlockInt32<1, 8> result;
+ result.buf.reg[0] =
+ RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
+ result.buf.reg[1] =
+ RoundingDivideByPOT(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
+ return result;
+ }
+};
+
} // end namespace gemmlowp
#include "simd_wrappers_common_neon_sse.h"
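
The RoundingDivideByPOT specializations above are per-lane rounding arithmetic right shifts: x is divided by 2^exponent, rounded to nearest with ties away from zero (ShiftLeft is simply a per-lane left shift). A scalar sketch of that behaviour, assuming 0 <= exponent <= 31, along the lines of the generic fixed-point path:

    #include <cstdint>

    inline std::int32_t RoundingDivideByPOTScalar(std::int32_t x, int exponent) {
      const std::int32_t mask = static_cast<std::int32_t>((1LL << exponent) - 1);
      const std::int32_t remainder = x & mask;
      // For negative x the threshold is one higher, so exact ties round away
      // from zero in both directions.
      const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
      return (x >> exponent) + (remainder > threshold ? 1 : 0);
    }
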
diff --git a/internal/unpack.h b/internal/unpack.h
index 33aee13..021f4aa 100644
--- a/internal/unpack.h
+++ b/internal/unpack.h
@@ -98,12 +98,14 @@ void UnpackResultBlock(const SrcMapType& src,
const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
int depth, int src_row, int src_col, int src_global_row,
int src_global_col, int dst_row, int dst_col) {
+ using KernelLhsInputScalar = typename KernelFormat::Lhs::InputScalar;
using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
+ using KernelRhsInputScalar = typename KernelFormat::Rhs::InputScalar;
using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
static constexpr int KernelLhsZeroPointInput =
- ZeroPointInputValue<KernelLhsScalar>::kValue;
+ ZeroPointInputValue<KernelLhsInputScalar, KernelLhsScalar>::kValue;
static constexpr int KernelRhsZeroPointInput =
- ZeroPointInputValue<KernelRhsScalar>::kValue;
+ ZeroPointInputValue<KernelRhsInputScalar, KernelRhsScalar>::kValue;
auto acc = Load<RegisterBlockType>(src, src_row, src_col);
const auto& lhs_sums_of_each_slice_block =
LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row);
diff --git a/meta/multi_thread_common.h b/meta/multi_thread_common.h
index 0b35759..b39c3f2 100644
--- a/meta/multi_thread_common.h
+++ b/meta/multi_thread_common.h
@@ -22,9 +22,15 @@ namespace meta {
inline int ResolveMaxThreads(int max_threads) {
if (max_threads == 0) {
+#ifdef _WIN32
+ SYSTEM_INFO sysinfo;
+ GetSystemInfo(&sysinfo);
+ return sysinfo.dwNumberOfProcessors;
+#else
static const int hardware_threads_count =
static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));
return hardware_threads_count;
+#endif
}
return max_threads;
}
diff --git a/profiling/instrumentation.h b/profiling/instrumentation.h
index 437fe54..c1f852e 100644
--- a/profiling/instrumentation.h
+++ b/profiling/instrumentation.h
@@ -108,13 +108,14 @@ struct ScopedLock {
// contains pointers to literal strings that were manually entered
// in the instrumented code (see ScopedProfilingLabel).
struct ProfilingStack {
- static const std::size_t kMaxSize = 14;
+ static const std::size_t kMaxSize = 30;
typedef const char* LabelsArrayType[kMaxSize];
LabelsArrayType labels;
std::size_t size;
Mutex* lock;
ProfilingStack() { memset(this, 0, sizeof(ProfilingStack)); }
+ ~ProfilingStack() { delete lock; }
void Push(const char* label) {
ScopedLock sl(lock);
@@ -171,8 +172,6 @@ struct ThreadInfo {
ScopedLock sl(GlobalMutexes::Profiler());
ThreadInfo* self = static_cast<ThreadInfo*>(ptr);
ThreadsUnderProfiling().erase(self);
- pthread_key_delete(self->key);
- delete self->stack.lock;
}
};
@@ -185,7 +184,11 @@ inline ThreadInfo& ThreadLocalThreadInfo() {
}
};
- static int key_result = pthread_key_create(&key, DeleteThreadInfo);
+ // key_result is unused. The purpose of this 'static' local object is
+ // to have its initializer (the pthread_key_create call) performed exactly
+ // once, in a way that is guaranteed (since C++11) to be thread-safe.
+ static const int key_result = pthread_key_create(&key, DeleteThreadInfo);
+ (void)key_result;
ThreadInfo* threadInfo = static_cast<ThreadInfo*>(pthread_getspecific(key));
if (!threadInfo) {
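
The pattern above relies on the C++11 guarantee that a function-local static is initialized exactly once, with other threads blocking until that initialization completes ("magic statics"). A generic sketch of the idiom, using hypothetical names that are not gemmlowp code:

    #include <cstdio>

    static int ExpensiveSetup() {
      std::puts("setup runs exactly once");
      return 42;
    }

    int CachedValue() {
      // Since C++11, this initializer runs exactly once even if several
      // threads call CachedValue() concurrently on first use.
      static const int value = ExpensiveSetup();
      return value;
    }
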
diff --git a/profiling/pthread_everywhere.h b/profiling/pthread_everywhere.h
index df17c6f..2569bbc 100644
--- a/profiling/pthread_everywhere.h
+++ b/profiling/pthread_everywhere.h
@@ -60,6 +60,9 @@ inline void pthread_cond_init(pthread_cond_t *cond, std::nullptr_t) {
*cond = new std::condition_variable;
}
inline void pthread_cond_signal(pthread_cond_t *cond) { (*cond)->notify_one(); }
+inline void pthread_cond_broadcast(pthread_cond_t *cond) {
+ (*cond)->notify_all();
+}
inline void pthread_cond_wait(pthread_cond_t *cond, pthread_mutex_t *mutex) {
std::unique_lock<std::mutex> lock(**mutex, std::adopt_lock);
(*cond)->wait(lock);
diff --git a/public/bit_depth.h b/public/bit_depth.h
index 6cb4ecf..412944e 100644
--- a/public/bit_depth.h
+++ b/public/bit_depth.h
@@ -24,14 +24,15 @@ template <int tMinValue, int tMaxValue>
struct OperandRange {
static const int kMinValue = tMinValue;
static const int kMaxValue = tMaxValue;
- static_assert(0 <= kMinValue, "");
static_assert(kMinValue < kMaxValue, "");
- static_assert(kMaxValue <= 255, "");
};
using Uint8Range = OperandRange<0, 255>;
using Uint8RangeExcludingZero = OperandRange<1, 255>;
+using Int8Range = OperandRange<-128, 127>;
+using Int8RangeExcludingLow = OperandRange<-127, 127>;
+
template <typename tLhsRange, typename tRhsRange>
struct BitDepthParams {
using LhsRange = tLhsRange;
@@ -47,6 +48,11 @@ using DefaultL8R8BitDepthParams = BitDepthParams<Uint8Range, Uint8Range>;
using L8R8WithLhsNonzeroBitDepthParams =
BitDepthParams<Uint8RangeExcludingZero, Uint8Range>;
+// Signed variant: this enables faster kernels based on signed arithmetic; see
+// NEON_64bit_GEMM_Int8Operands_Int32Accumulators_AccumTwoWithin16Bits
+using SignedL8R8WithLhsNonzeroBitDepthParams =
+ BitDepthParams<Int8RangeExcludingLow, Int8Range>;
+
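
With these definitions in place, opting into the signed fast path is, at the API level, a matter of instantiating the public entry point with int8 operands and the new BitDepthParams. The following is a hedged sketch only; the shapes, offsets and quantization parameters are made up, and whether this exact instantiation works end to end depends on the int8 kernels and packing added elsewhere in this change:

    #include <cstdint>
    #include <tuple>
    #include <vector>
    #include "public/gemmlowp.h"       // paths as used inside the gemmlowp tree
    #include "public/output_stages.h"

    int main() {
      const int rows = 4, depth = 16, cols = 4;
      std::vector<std::int8_t> lhs_data(rows * depth, 1);  // must avoid -128
      std::vector<std::int8_t> rhs_data(depth * cols, 1);
      std::vector<std::uint8_t> result_data(rows * cols, 0);

      gemmlowp::MatrixMap<const std::int8_t, gemmlowp::MapOrder::RowMajor> lhs(
          lhs_data.data(), rows, depth);
      gemmlowp::MatrixMap<const std::int8_t, gemmlowp::MapOrder::ColMajor> rhs(
          rhs_data.data(), depth, cols);
      gemmlowp::MatrixMap<std::uint8_t, gemmlowp::MapOrder::ColMajor> result(
          result_data.data(), rows, cols);

      gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down;
      quantize_down.result_fixedpoint_multiplier = 1 << 30;  // i.e. 0.5
      quantize_down.result_shift = 4;
      quantize_down.result_offset_after_shift = 0;
      gemmlowp::OutputStageSaturatingCastToUint8 cast_to_uint8;
      const auto pipeline = std::make_tuple(quantize_down, cast_to_uint8);

      gemmlowp::GemmContext context;
      gemmlowp::GemmWithOutputPipeline<
          std::int8_t, std::uint8_t,
          gemmlowp::SignedL8R8WithLhsNonzeroBitDepthParams>(
          &context, lhs, rhs, &result, /*lhs_offset=*/0, /*rhs_offset=*/0,
          pipeline);
      return 0;
    }
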
// Deprecated: when gemmlowp used to allow requantizing 8bit
// inputs to less-than-8-bit depths, the public setting allowing
// that was DefaultL7R5BitDepthParams. That requantization
diff --git a/public/map.h b/public/map.h
index 3073e05..fe6bc5c 100644
--- a/public/map.h
+++ b/public/map.h
@@ -131,6 +131,7 @@ class VectorDup {
assert(start >= 0);
assert(start + len <= size_);
+ (void)start;
return VectorDup(data_, len);
}
};
diff --git a/public/output_stages.h b/public/output_stages.h
index 1d5fca4..797b662 100644
--- a/public/output_stages.h
+++ b/public/output_stages.h
@@ -138,12 +138,44 @@ struct OutputStageScaleInt32ByFixedPointAndExponent {
std::int32_t result_offset_after_shift;
};
+// Variant of OutputStageQuantizeDownInt32ByFixedPoint where the 'shift'
+// is not necessarily just a right shift, so we can represent multipliers
+// greater than 1. This takes a result_exponent parameter; when it's
+// <= 0, this is equivalent to OutputStageQuantizeDownInt32ByFixedPoint
+// with result_shift = -result_exponent.
+// In the general case, this first left-shifts by
+// std::max(result_exponent, 0), then does the same as
+// OutputStageQuantizeDownInt32ByFixedPoint with
+// result_shift = std::max(-result_exponent, 0).
+//
+// The difference from OutputStageScaleInt32ByFixedPointAndExponent is that
+// each row or column of the output (depending on tShape) has its own
+// result_fixedpoint_multiplier and result_exponent numbers.
+template <VectorShape tShape>
+struct OutputStageScaleInt32ByFixedPointAndExponentPC {
+ VectorMap<const std::int32_t, tShape> result_fixedpoint_multiplier;
+ VectorMap<const std::int32_t, tShape> result_exponent;
+ std::int32_t result_offset_after_shift;
+};
+
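
As a concrete reading of the comment above, here is a scalar sketch of the scaling applied to a single accumulator, given the multiplier and exponent of its row (or column); the actual implementation operates on whole register blocks in internal/output.h:

    #include <algorithm>
    #include <cstdint>
    #include "fixedpoint/fixedpoint.h"  // SaturatingRoundingDoublingHighMul, RoundingDivideByPOT

    // exponent > 0: net multiplier greater than 1 (left shift first);
    // exponent <= 0: equivalent to result_shift = -exponent.
    inline std::int32_t ScalePerChannelScalar(std::int32_t acc,
                                              std::int32_t fixedpoint_multiplier,
                                              int exponent,
                                              std::int32_t offset_after_shift) {
      const int left_shift = std::max(exponent, 0);
      const int right_shift = std::max(-exponent, 0);
      const std::int32_t scaled = gemmlowp::SaturatingRoundingDoublingHighMul(
          acc * (1 << left_shift), fixedpoint_multiplier);
      return gemmlowp::RoundingDivideByPOT(scaled, right_shift) +
             offset_after_shift;
    }
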
// This output stage takes int32 values that are expected to be already
// on the final uint8 scale, but not necessarily in the [0..255] range.
// It clamps them to the [0..255] range and returns them casted to uint8.
struct OutputStageSaturatingCastToUint8 {};
// This output stage takes int32 values that are expected to be already
+// on the final int8 scale, but not necessarily in the [-128..127] range.
+// It clamps them to the [-128..127] range and returns them casted to int8.
+struct OutputStageSaturatingCastToInt8 {};
+
+// This output stage takes int32 values that are expected to be already
+// in the [0..255] range and returns them casted to uint8.
+// This stage can save time if used instead of the
+// OutputStageSaturatingCastToUint8 stage immediately after the
+// OutputStageClamp stage.
+struct OutputStageTruncatingCastToUint8 {};
+
+// This output stage takes int32 values that are expected to be already
// on the final int16 scale, but not necessarily in the [-32768..32767] range.
// It clamps them to the [-32768..32767] range and returns them casted to int16.
struct OutputStageSaturatingCastToInt16 {};