author | Miao Wang <miaowang@google.com> | 2019-08-27 17:50:04 -0700 |
---|---|---|
committer | android-build-merger <android-build-merger@google.com> | 2019-08-27 17:50:04 -0700 |
commit | 68dcc597e650eeda114b77f22e6391e85f4c5437 (patch) | |
tree | f910ae75e271bc79a22d9a73b94da6bdbe01d330 | |
parent | 032b4e313c03a94aac1769c0a81ae7d49943fe4b (diff) | |
parent | e8a1111f830a39e429ecbab08972c370fe9dcfb0 (diff) | |
download | gemmlowp-68dcc597e650eeda114b77f22e6391e85f4c5437.tar.gz |
Rebase gemmlowp to a227af1fdb47f250b5df07d6936366b0f8113b65 am: 70ba50cbca am: 36f90a2b7a am: 846c903a24
am: e8a1111f83
Change-Id: Ibd4d3c2ce93aec1001c278e313f98770d6b6676d
39 files changed, 3938 insertions, 560 deletions
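Among the changes pulled in by this rebase, fixedpoint.h replaces the scalar `ShiftLeft` (previously a plain `a << offset`, with undefined behavior on overflow) by a variant that widens to 64 bits and clamps to the output type's range. Below is a minimal standalone sketch of that clamping idea, using a hypothetical helper name rather than gemmlowp's actual template; the real change, including the new `OffsetType` template parameter and the matching vector specializations, is in the fixedpoint/fixedpoint.h hunk further down.

```cpp
#include <cstdint>
#include <iostream>
#include <limits>

// Hypothetical stand-in for the saturating ShiftLeft added in
// fixedpoint/fixedpoint.h: widen to 64 bits, multiply by 2^offset,
// then clamp to the 32-bit range instead of overflowing.
std::int32_t SaturatingShiftLeft(std::int32_t a, int offset) {
  const std::int64_t wide = static_cast<std::int64_t>(a) * (std::int64_t{1} << offset);
  const std::int64_t lo = std::numeric_limits<std::int32_t>::min();
  const std::int64_t hi = std::numeric_limits<std::int32_t>::max();
  if (wide < lo) return static_cast<std::int32_t>(lo);
  if (wide > hi) return static_cast<std::int32_t>(hi);
  return static_cast<std::int32_t>(wide);
}

int main() {
  std::cout << SaturatingShiftLeft(3, 4) << "\n";           // 48, same as 3 << 4
  std::cout << SaturatingShiftLeft(1 << 30, 2) << "\n";     // clamps to 2147483647
  std::cout << SaturatingShiftLeft(-(1 << 30), 3) << "\n";  // clamps to -2147483648
}
```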
diff --git a/doc/kernel.md b/doc/kernel.md
index 261cb92..f3f2138 100644
--- a/doc/kernel.md
+++ b/doc/kernel.md
@@ -40,11 +40,15 @@ NEONKernel12x4Depth2 kernel, which specifies its format as
 
 The meaning of these terms is explained in the lengthy comment at the top of
 internal/kernel.h. Here, they mean that this kernel handles at each iteration
-(along the depth dimension): - 3 'cells' of size 4x2 each of the lhs, so a total
-lhs block of size 12x2 - 1 'cell' of size 2x4 of the rhs. In other words, this
-kernel handles 12 rows of the lhs and 4 columns of the rhs, and handles two
-levels of depth at once. The 'cells' and `CellFormat` detail the layout of these
-12x2 and 2x4 blocks.
+(along the depth dimension):
+
+- 3 'cells' of size 4x2 each of the lhs, so a total lhs block of size 12x2
+
+- 1 'cell' of size 2x4 of the rhs.
+
+In other words, this kernel handles 12 rows of the lhs and 4 columns of the
+rhs, and handles two levels of depth at once. The 'cells' and `CellFormat`
+detail the layout of these 12x2 and 2x4 blocks.
 
 This kernel then loads these 12x2 and 2x4 blocks and computes the corresponding
 12x4 GEMM; for ease of reference let us paste the critical comment and code
diff --git a/doc/public.md b/doc/public.md
index 935f6db..7739b85 100644
--- a/doc/public.md
+++ b/doc/public.md
@@ -14,7 +14,7 @@ The high-level overview of how this specifies a low-precision
 matrix multiplication is explained in [low-precision.md](low-precision.md). The
 rationale for a specific quantization paradigm is given in
 [quantization.md](quantization.md). That specific quantization paradigm is
-implemented at two different stages of the computation: as pre-processing ont
+implemented at two different stages of the computation: as pre-processing on
 the operands and as post-processing on the result:
 
 * Pre-processing on the LHS, RHS operands, in the form of adding constant
@@ -56,7 +56,7 @@ being automatically deduced from function parameters:
 
 * `InputScalar`: The scalar type of the LHS and RHS operands. At the moment,
   this must be `std::uint8_t`.
-* `OutputScalar`: The scalar type of the LHS and RHS operands. At the moment,
+* `OutputScalar`: The scalar type of the result. At the moment,
   this must be `std::uint8_t`.
 * `BitDepthParams`: Defines the bit format of the input and output matrices
   and the required accuracy of the computation. At the moment, the only
diff --git a/doc/quantization.md b/doc/quantization.md
index 3a8f72b..e5055e7 100644
--- a/doc/quantization.md
+++ b/doc/quantization.md
@@ -13,7 +13,7 @@ quantization paradigm affects the calculations that gemmlowp itself needs to
 perform, specifically, it affects how one goes from internal 32bit accumulator
 to final 8bit outputs.
 
-The part of gemmlowp transforming internal internal 32bit accumulator to final
+The part of gemmlowp transforming internal 32bit accumulator to final
 8bit outputs is the "output pipeline" described in [output.md](output.md).
 
 gemmlowp's `GemmWithOutputPipeline` entry point allows specifying an arbitrary
diff --git a/eight_bit_int_gemm/eight_bit_int_gemm.cc b/eight_bit_int_gemm/eight_bit_int_gemm.cc
index 512c483..a8d9b43 100644
--- a/eight_bit_int_gemm/eight_bit_int_gemm.cc
+++ b/eight_bit_int_gemm/eight_bit_int_gemm.cc
@@ -12,9 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#endif
 #include "eight_bit_int_gemm.h"
 
 #include <memory>
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index d39341b..58e8050 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -18,10 +18,13 @@
 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
 
+#include <algorithm>
 #include <cassert>
+#include <cmath>
+#include <cstdint>
 #include <limits>
 
-#include "../internal/common.h"
+#include "../internal/detect_platform.h"
 
 namespace gemmlowp {
 
@@ -47,13 +50,13 @@ struct FixedPointRawTypeTraits {};
 template <>
 struct FixedPointRawTypeTraits<std::int32_t> {
   typedef std::int32_t ScalarRawType;
-  static const int kLanes = 1;
+  static constexpr int kLanes = 1;
 };
 
 template <>
 struct FixedPointRawTypeTraits<std::int16_t> {
   typedef std::int16_t ScalarRawType;
-  static const int kLanes = 1;
+  static constexpr int kLanes = 1;
 };
 
 // Returns a SIMD value duplicating a scalar value across all lanes.
@@ -109,11 +112,25 @@ tIntegerType Neg(tIntegerType a) {
   return -a;
 }
 
-// Integer arithmetic left-shift, equivalent to multiplying with a
-// power of two. Not saturating. Overflow is undefined behavior.
-template <typename tIntegerType>
-tIntegerType ShiftLeft(tIntegerType a, int offset) {
-  return a << offset;
+// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
+// Negative values are OK. In case of overflow, no Undefined
+// Behavior, but the results are implementation-defined (in practice,
+// they currently are saturated, but we make no commitment to that). The idea
+// is that the caller will want to implement the overflowing cases with
+// saturation with compare-and-mask, so we don't care about the results
+// in the overflow case, we just want to avoid undefined behavior.
+//
+// tIntegerType may be int32 or any narrower signed type.
+template <typename tIntegerType, typename OffsetType>
+tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
+  const std::int64_t wide_a = static_cast<std::int64_t>(a);
+  const std::int64_t wide_shifted = wide_a * (1 << offset);
+  const auto min = std::numeric_limits<tIntegerType>::min();
+  const auto max = std::numeric_limits<tIntegerType>::max();
+  return wide_shifted < min
+             ? min
+             : wide_shifted > max ? max
+                                  : static_cast<tIntegerType>(wide_shifted);
 }
 
 // Integer arithmetic right-shift. Not rounding.
@@ -137,7 +154,7 @@ tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
 // input scalar is non-zero.
 template <typename tIntegerType>
 tIntegerType MaskIfNonZero(tIntegerType a) {
-  static const tIntegerType zero = 0;
+  static constexpr tIntegerType zero = 0;
   return a ? BitNot(zero) : zero;
 }
 
@@ -211,6 +228,7 @@ bool Any(tIntegerType a) {
 template <typename IntegerType>
 IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  (void)b;
   return a;
 }
 
@@ -235,6 +253,7 @@ inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
 template <typename IntegerType>
 IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  (void)b;
   return a;
 }
 
@@ -244,7 +263,9 @@ inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
   std::int32_t a32 = a;
   std::int32_t b32 = b;
   std::int32_t sum = a32 + b32;
-  return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum)));
+  return static_cast<std::int16_t>(
+      std::min(static_cast<std::int32_t>(32767),
+               std::max(static_cast<std::int32_t>(-32768), sum)));
 }
 
 // Returns a+b, saturating if the integers are 16bit or narrower,
@@ -298,6 +319,7 @@ IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
 template <typename IntegerType>
 IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+  (void)b;
   return a;
 }
 
@@ -331,8 +353,8 @@ inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
 
 // Correctly-rounded-to-nearest division by a power-of-two.
 // Also known as a rounding arithmetic right shift.
-template <typename IntegerType>
-inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
+template <typename IntegerType, typename ExponentType>
+inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
   assert(exponent >= 0);
   assert(exponent <= 31);
   const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
@@ -432,9 +454,9 @@ class FixedPoint {
   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
 
-  static const int kTotalBits = 8 * sizeof(ScalarRawType);
-  static const int kIntegerBits = tIntegerBits;
-  static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
+  static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
+  static constexpr int kIntegerBits = tIntegerBits;
+  static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
                 "bad IntegerBits");
 
@@ -474,7 +496,7 @@ class FixedPoint {
 
   template <int Exponent>
   static FixedPoint ConstantPOT() {
-    static const int kOffset = kFractionalBits + Exponent;
+    static constexpr int kOffset = kFractionalBits + Exponent;
     static_assert(
         kOffset < 31,
         "Constant not exactly representable in this fixed-point format");
@@ -645,7 +667,7 @@ double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
 FixedPoint<tRawType, tIntegerBitsDst> Rescale(
     FixedPoint<tRawType, tIntegerBitsSrc> x) {
-  static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
+  static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
   FixedPoint<tRawType, tIntegerBitsDst> result;
   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
   return result;
@@ -725,9 +747,9 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
     FixedPoint<tRawType, tIntegerBits> a) {
   typedef FixedPoint<tRawType, tIntegerBits> InputF;
   typedef FixedPoint<tRawType, 0> ResultF;
-  static const int kFractionalBits = InputF::kFractionalBits;
-  static const int kIntegerBits = InputF::kIntegerBits;
-  static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
+  static constexpr int kFractionalBits = InputF::kFractionalBits;
+  static constexpr int kIntegerBits = InputF::kIntegerBits;
+  const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
   InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
   InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
   ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
@@ -755,10 +777,10 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
 
 #undef GEMMLOWP_EXP_BARREL_SHIFTER
 
+  static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
   if (kIntegerBits > 5) {
-    static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
     const InputF clamp =
-        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
+        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
     result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
   }
 
@@ -867,6 +889,8 @@ FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
 
 #ifdef GEMMLOWP_NEON
 #include "./fixedpoint_neon.h"
+#elif defined(GEMMLOWP_AVX2)
+#include "./fixedpoint_avx.h"
 #elif defined(GEMMLOWP_SSE4)
 #include "./fixedpoint_sse.h"
 #elif defined(GEMMLOWP_MSA)
diff --git a/fixedpoint/fixedpoint_avx.h b/fixedpoint/fixedpoint_avx.h
new file mode 100644
index 0000000..1816386
--- /dev/null
+++ b/fixedpoint/fixedpoint_avx.h
@@ -0,0 +1,218 @@
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// fixedpoint_avx.h: optimized avx specializations of the templates
+// in fixedpoint.h.
+ +#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ +#define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ + +#include <smmintrin.h> +#include "fixedpoint.h" +#include "fixedpoint_sse.h" + +namespace gemmlowp { + +template <> +struct FixedPointRawTypeTraits<__m256i> { + typedef std::int32_t ScalarRawType; + static const int kLanes = 4; +}; + +template <> +inline __m256i BitAnd(__m256i a, __m256i b) { + return _mm256_and_si256(a, b); +} + +template <> +inline __m256i BitOr(__m256i a, __m256i b) { + return _mm256_or_si256(a, b); +} + +template <> +inline __m256i BitXor(__m256i a, __m256i b) { + return _mm256_xor_si256(a, b); +} + +template <> +inline __m256i BitNot(__m256i a) { + return _mm256_andnot_si256(a, _mm256_set1_epi32(-1)); +} + +template <> +inline __m256i Add(__m256i a, __m256i b) { + return _mm256_add_epi32(a, b); +} + +template <> +inline __m256i Mul(__m256i a, __m256i b) { + return _mm256_mullo_epi32(a, b); +} + +template <> +inline __m256i Sub(__m256i a, __m256i b) { + return _mm256_sub_epi32(a, b); +} + +template <> +inline __m256i Neg(__m256i a) { + return _mm256_sign_epi32(a, _mm256_set1_epi32(-1)); +} + +template <> +inline __m256i ShiftLeft(__m256i a, int offset) { + return _mm256_slli_epi32(a, offset); +} + +template <> +inline __m256i ShiftRight(__m256i a, int offset) { + return _mm256_srai_epi32(a, offset); +} + +template <> +inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val, + __m256i else_val) { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(else_val), + _mm256_castsi256_ps(then_val), + _mm256_castsi256_ps(if_mask))); +} + +template <> +inline __m256i MaskIfEqual(__m256i a, __m256i b) { + return _mm256_cmpeq_epi32(a, b); +} + +template <> +inline __m256i MaskIfNotEqual(__m256i a, __m256i b) { + return BitNot(MaskIfEqual(a, b)); +} + +template <> +inline __m256i MaskIfZero(__m256i a) { + return MaskIfEqual(a, _mm256_set1_epi32(0)); +} + +template <> +inline __m256i MaskIfNonZero(__m256i a) { + return MaskIfNotEqual(a, _mm256_set1_epi32(0)); +} + +template <> +inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) { + return _mm256_cmpgt_epi32(a, b); +} + +template <> +inline __m256i MaskIfLessThan(__m256i a, __m256i b) { + return _mm256_cmpgt_epi32(b, a); +} + +template <> +inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) { + return BitNot(MaskIfLessThan(a, b)); +} + +template <> +inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) { + return BitNot(MaskIfGreaterThan(a, b)); +} + +/* Assumptions: + - All and Any are used on masks. + - masks are all_ones for true lanes, all_zeroes otherwise. +Hence, All means all 128bits set, and Any means any bit set. 
+*/ + +template <> +inline bool All(__m256i a) { + return _mm256_testc_si256(a, a); +} + +template <> +inline bool Any(__m256i a) { + return BitNot(_mm256_testz_si256(a, a)); +} + +template <> +inline __m256i RoundingHalfSum(__m256i a, __m256i b) { + /* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */ + /* We divide the inputs before the add to avoid the overflow and costly test + */ + /* of checking if an overflow occured on signed add */ + /* round_bit_mask = _mm_set1_epi32(1); */ + /* a_over_2 = _mm_srai_epi32(a, 1); */ + /* b_over_2 = _mm_srai_epi32(b, 1); */ + /* sum = Add(a_over_2, b_over_2); */ + /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a,b), round_bit_mask), sum); */ + /* return Add(sum, round_bit); */ + + /* Other possibility detecting overflow and xor the sign if an overflow + * happened*/ + __m256i one, sign_bit_mask, sum, rounded_half_sum, overflow, result; + one = _mm256_set1_epi32(1); + sign_bit_mask = _mm256_set1_epi32(0x80000000); + sum = Add(a, b); + rounded_half_sum = _mm256_srai_epi32(Add(sum, one), 1); + overflow = + BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)), + sign_bit_mask); + result = BitXor(rounded_half_sum, overflow); + return result; +} + +template <> +inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) { + __m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3; + __m256i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded; + __m256i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result; + __m256i nudge; + + // saturation only happen if a == b == INT_MIN + min = _mm256_set1_epi32(std::numeric_limits<std::int32_t>::min()); + saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min)); + + // a = a0 | a1 | a2 | a3 + // b = b0 | b1 | b2 | b3 + a0_a2 = a; + a1_a3 = _mm256_srli_si256(a, 4); + b0_b2 = b; + b1_b3 = _mm256_srli_si256(b, 4); + + a0b0_a2b2 = _mm256_mul_epi32(a0_a2, b0_b2); + a1b1_a3b3 = _mm256_mul_epi32(a1_a3, b1_b3); + + // do the rounding and take into account that it will be doubled + nudge = _mm256_set1_epi64x(1 << 30); + a0b0_a2b2_rounded = _mm256_add_epi64(a0b0_a2b2, nudge); + a1b1_a3b3_rounded = _mm256_add_epi64(a1b1_a3b3, nudge); + + // do the doubling + a0b0_a2b2_rounded_2x = _mm256_slli_epi64(a0b0_a2b2_rounded, 1); + a1b1_a3b3_rounded_2x = _mm256_slli_epi64(a1b1_a3b3_rounded, 1); + + // get the high part of the products + result = _mm256_blend_epi16(_mm256_srli_si256(a0b0_a2b2_rounded_2x, 4), + a1b1_a3b3_rounded_2x, 0xcc); + + // saturate those which overflowed + return SelectUsingMask(saturation_mask, min, result); +} + +template <> +inline __m256i Dup<__m256i>(std::int32_t x) { + return _mm256_set1_epi32(x); +} + +} // end namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h index c7a110c..b17f32a 100644 --- a/fixedpoint/fixedpoint_msa.h +++ b/fixedpoint/fixedpoint_msa.h @@ -25,13 +25,13 @@ namespace gemmlowp { template <> struct FixedPointRawTypeTraits<v4i32> { typedef std::int32_t ScalarRawType; - static const int kLanes = 4; + static constexpr int kLanes = 4; }; template <> struct FixedPointRawTypeTraits<v8i16> { typedef std::int16_t ScalarRawType; - static const int kLanes = 8; + static constexpr int kLanes = 8; }; template <> @@ -326,11 +326,71 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> { } }; -// TODO: possibly implement: -// template <> v4i32 RoundingDivideByPOT(v4i32, int) -// template <> v8i16 RoundingDivideByPOT(v8i16, int) -// template <int Exponent> 
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> -// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> +template <int Exponent> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> { + static v4i32 eval(v4i32 x) { + static_assert(-31 <= Exponent && Exponent <= -1, ""); + // Isolate the sign bits. + v4i32 sign = __builtin_msa_srli_w(x, 31); + // Decrement the negative elements by 1 (with saturation). + x = __builtin_msa_subs_s_w(x, sign); + // Arithmetic shift right with rounding. + // The srari instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srari_w(x, -Exponent); + } +}; + +template <int Exponent> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> { + static v8i16 eval(v8i16 x) { + static_assert(-15 <= Exponent && Exponent <= -1, ""); + // Isolate the sign bits. + v8i16 sign = __builtin_msa_srli_h(x, 15); + // Decrement the negative elements by 1 (with saturation). + x = __builtin_msa_subs_s_h(x, sign); + // Arithmetic shift right with rounding. + // The srari instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srari_h(x, -Exponent); + } +}; + +template <> +inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) { + v4i32 e = __builtin_msa_fill_w(exponent); + // Isolate the sign bits. + v4i32 sign = __builtin_msa_srli_w(x, 31); + // Reset them to 0 if exponent is 0. + sign = __builtin_msa_min_s_w(sign, e); + // Decrement the negative elements by 1 (with saturation) + // if exponent is non-zero. + x = __builtin_msa_subs_s_w(x, sign); + // Arithmetic shift right with rounding. + // The srar instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srar_w(x, e); +} + +template <> +inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) { + v8i16 e = __builtin_msa_fill_h(exponent); + // Isolate the sign bits. + v8i16 sign = __builtin_msa_srli_h(x, 15); + // Reset them to 0 if exponent is 0. + sign = __builtin_msa_min_s_h(sign, e); + // Decrement the negative elements by 1 (with saturation) + // if exponent is non-zero. + x = __builtin_msa_subs_s_h(x, sign); + // Arithmetic shift right with rounding. + // The srar instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. 
+ return __builtin_msa_srar_h(x, e); +} template <> inline v4i32 Dup<v4i32>(std::int32_t x) { @@ -346,7 +406,6 @@ inline v8i16 Dup<v8i16>(std::int16_t x) { template <> inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) { return __builtin_msa_adds_s_h(a, b); - return a; } } // end namespace gemmlowp diff --git a/fixedpoint/fixedpoint_neon.h b/fixedpoint/fixedpoint_neon.h index 92b349b..4dab6c9 100644 --- a/fixedpoint/fixedpoint_neon.h +++ b/fixedpoint/fixedpoint_neon.h @@ -25,13 +25,13 @@ namespace gemmlowp { template <> struct FixedPointRawTypeTraits<int32x4_t> { typedef std::int32_t ScalarRawType; - static const int kLanes = 4; + static constexpr int kLanes = 4; }; template <> struct FixedPointRawTypeTraits<int16x8_t> { typedef std::int16_t ScalarRawType; - static const int kLanes = 8; + static constexpr int kLanes = 8; }; template <> @@ -115,6 +115,16 @@ inline int16x8_t ShiftLeft(int16x8_t a, int offset) { } template <> +inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) { + return vshlq_s32(a, offset); +} + +template <> +inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) { + return vshlq_s16(a, offset); +} + +template <> inline int32x4_t ShiftRight(int32x4_t a, int offset) { return vshlq_s32(a, vdupq_n_s32(-offset)); } @@ -282,6 +292,22 @@ inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) { return vrshlq_s16(fixed_up_x, shift_vec); } +template <> +inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) { + const int32x4_t shift_vec = vnegq_s32(exponent); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed_up_x, shift_vec); +} + +template <> +inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) { + const int16x8_t shift_vec = vnegq_s16(exponent); + const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15); + const int16x8_t fixed_up_x = vqaddq_s16(x, fixup); + return vrshlq_s16(fixed_up_x, shift_vec); +} + template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> { static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); } diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h index ba990f0..a1fae32 100644 --- a/fixedpoint/fixedpoint_sse.h +++ b/fixedpoint/fixedpoint_sse.h @@ -42,13 +42,13 @@ struct int16x8_m128i { template <> struct FixedPointRawTypeTraits<__m128i> { typedef std::int32_t ScalarRawType; - static const int kLanes = 4; + static constexpr int kLanes = 4; }; template <> struct FixedPointRawTypeTraits<int16x8_m128i> { typedef std::int16_t ScalarRawType; - static const int kLanes = 8; + static constexpr int kLanes = 8; }; template <> diff --git a/internal/common.h b/internal/common.h index 26b6713..332ad07 100644 --- a/internal/common.h +++ b/internal/common.h @@ -26,144 +26,9 @@ #include <cmath> #include <cstdlib> +#include "../internal/detect_platform.h" #include "../profiling/instrumentation.h" -// Our inline assembly path assume GCC/Clang syntax. -// Native Client doesn't seem to support inline assembly(?). -#if defined(__GNUC__) && !defined(__native_client__) -#define GEMMLOWP_ALLOW_INLINE_ASM -#endif - -// Define macro statement that avoids inlining for GCC. -// For non-GCC, define as empty macro. 
-#if defined(__GNUC__) -#define GEMMLOWP_NOINLINE __attribute__((noinline)) -#else -#define GEMMLOWP_NOINLINE -#endif - -// Detect ARM, 32-bit or 64-bit -#ifdef __arm__ -#define GEMMLOWP_ARM_32 -#endif - -#ifdef __aarch64__ -#define GEMMLOWP_ARM_64 -#endif - -#if defined(GEMMLOWP_ARM_32) || defined(GEMMLOWP_ARM_64) -#define GEMMLOWP_ARM -#endif - -// Detect MIPS, 32-bit or 64-bit -#if defined(__mips) && !defined(__LP64__) -#define GEMMLOWP_MIPS_32 -#endif - -#if defined(__mips) && defined(__LP64__) -#define GEMMLOWP_MIPS_64 -#endif - -#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64) -#define GEMMLOWP_MIPS -#endif - -// Detect x86, 32-bit or 64-bit -#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386) -#define GEMMLOWP_X86_32 -#endif - -#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64) -#define GEMMLOWP_X86_64 -#endif - -#if defined(GEMMLOWP_X86_32) || defined(GEMMLOWP_X86_64) -#define GEMMLOWP_X86 -#endif - -// Some of our optimized paths use inline assembly and for -// now we don't bother enabling some other optimized paths using intrinddics -// where we can't use inline assembly paths. -#ifdef GEMMLOWP_ALLOW_INLINE_ASM - -// Detect NEON. It's important to check for both tokens. -#if (defined __ARM_NEON) || (defined __ARM_NEON__) -#define GEMMLOWP_NEON -#endif - -// Convenience NEON tokens for 32-bit or 64-bit -#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_32) -#define GEMMLOWP_NEON_32 -#endif - -#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_64) -#define GEMMLOWP_NEON_64 -#endif - -// Detect MIPS MSA. -// Limit MSA optimizations to little-endian CPUs for now. -// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs? -#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \ - defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) -#define GEMMLOWP_MSA -#endif - -// Convenience MIPS MSA tokens for 32-bit or 64-bit. -#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32) -#define GEMMLOWP_MSA_32 -#endif - -#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64) -#define GEMMLOWP_MSA_64 -#endif - -// Detect SSE. -#ifdef __SSE4_1__ -#define GEMMLOWP_SSE4 -#endif - -#ifdef __SSE3__ -#define GEMMLOWP_SSE3 -#endif - -// Convenience SSE4 tokens for 32-bit or 64-bit -#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \ - !defined(GEMMLOWP_DISABLE_SSE4) -#define GEMMLOWP_SSE4_32 -#endif - -#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32) -#define GEMMLOWP_SSE3_32 -#endif - -#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \ - !defined(GEMMLOWP_DISABLE_SSE4) -#define GEMMLOWP_SSE4_64 -#endif - -#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64) -#define GEMMLOWP_SSE3_64 -#endif - -#if defined(__has_feature) -#if __has_feature(memory_sanitizer) -#include <sanitizer/msan_interface.h> -#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __msan_unpoison -#elif __has_feature(address_sanitizer) -#include <sanitizer/asan_interface.h> -#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __asan_unpoison_memory_region -#endif -#endif - -#endif // GEMMLOWP_ALLOW_INLINE_ASM - -// Detect Android. Don't conflate with ARM - we care about tuning -// for non-ARM Android devices too. This can be used in conjunction -// with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs. -#if defined(__ANDROID__) || defined(ANDROID) -#define GEMMLOWP_ANDROID -#endif - namespace gemmlowp { // Standard cache line size. 
Useful to optimize alignment and @@ -242,7 +107,12 @@ const float kDefaultL2RhsFactor = 0.75f; // size, so any size would work there. Different platforms may set this // to different values but must ensure that their own optimized packing paths // are consistent with this value. + +#ifdef GEMMLOWP_AVX2 +const int kRegisterSize = 32; +#else const int kRegisterSize = 16; +#endif // Hints the CPU to prefetch the cache line containing ptr. inline void Prefetch(const void* ptr) { diff --git a/internal/detect_platform.h b/internal/detect_platform.h new file mode 100644 index 0000000..6f06d19 --- /dev/null +++ b/internal/detect_platform.h @@ -0,0 +1,166 @@ +// Copyright 2018 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// detect_platform.h: Sets up macros that control architecture-specific +// features of gemmlowp's implementation. + +#ifndef GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_ +#define GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_ + +// Our inline assembly path assume GCC/Clang syntax. +// Native Client doesn't seem to support inline assembly(?). +#if defined(__GNUC__) && !defined(__native_client__) +#define GEMMLOWP_ALLOW_INLINE_ASM +#endif + +// Define macro statement that avoids inlining for GCC. +// For non-GCC, define as empty macro. +#if defined(__GNUC__) +#define GEMMLOWP_NOINLINE __attribute__((noinline)) +#else +#define GEMMLOWP_NOINLINE +#endif + +// Detect ARM, 32-bit or 64-bit +#ifdef __arm__ +#define GEMMLOWP_ARM_32 +#endif + +#ifdef __aarch64__ +#define GEMMLOWP_ARM_64 +#endif + +#if defined(GEMMLOWP_ARM_32) || defined(GEMMLOWP_ARM_64) +#define GEMMLOWP_ARM +#endif + +// Detect MIPS, 32-bit or 64-bit +#if defined(__mips) && !defined(__LP64__) +#define GEMMLOWP_MIPS_32 +#endif + +#if defined(__mips) && defined(__LP64__) +#define GEMMLOWP_MIPS_64 +#endif + +#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64) +#define GEMMLOWP_MIPS +#endif + +// Detect x86, 32-bit or 64-bit +#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386) +#define GEMMLOWP_X86_32 +#endif + +#if defined(__x86_64__) || defined(_M_X64) || defined(__amd64) +#define GEMMLOWP_X86_64 +#endif + +#if defined(GEMMLOWP_X86_32) || defined(GEMMLOWP_X86_64) +#define GEMMLOWP_X86 +#endif + +// Some of our optimized paths use inline assembly and for +// now we don't bother enabling some other optimized paths using intrinddics +// where we can't use inline assembly paths. +#ifdef GEMMLOWP_ALLOW_INLINE_ASM + +// Detect NEON. It's important to check for both tokens. +#if (defined __ARM_NEON) || (defined __ARM_NEON__) +#define GEMMLOWP_NEON +#endif + +// Convenience NEON tokens for 32-bit or 64-bit +#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_32) +#define GEMMLOWP_NEON_32 +#endif + +#if defined(GEMMLOWP_NEON) && defined(GEMMLOWP_ARM_64) +#define GEMMLOWP_NEON_64 +#endif + +// Detect MIPS MSA. +// Limit MSA optimizations to little-endian CPUs for now. +// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs? 
+#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \ + defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) +#define GEMMLOWP_MSA +#endif + +// Convenience MIPS MSA tokens for 32-bit or 64-bit. +#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32) +#define GEMMLOWP_MSA_32 +#endif + +#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64) +#define GEMMLOWP_MSA_64 +#endif + +// compiler define for AVX2 -D GEMMLOWP_ENABLE_AVX2 +// Detect AVX2 +#if defined(__AVX2__) && defined(GEMMLOWP_ENABLE_AVX2) +#define GEMMLOWP_AVX2 +// Detect SSE4. +// MSVC does not have __SSE4_1__ macro, but will enable SSE4 +// when AVX is turned on. +#elif defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__)) +#define GEMMLOWP_SSE4 +// Detect SSE3. +#elif defined(__SSE3__) +#define GEMMLOWP_SSE3 +#endif + +// Convenience SSE4 tokens for 32-bit or 64-bit +#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \ + !defined(GEMMLOWP_DISABLE_SSE4) +#define GEMMLOWP_SSE4_32 +#endif + +#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_32) +#define GEMMLOWP_SSE3_32 +#endif + +#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \ + !defined(GEMMLOWP_DISABLE_SSE4) +#define GEMMLOWP_SSE4_64 +#endif + +#if defined(GEMMLOWP_SSE3) && defined(GEMMLOWP_X86_64) +#define GEMMLOWP_SSE3_64 +#endif + +#if defined(GEMMLOWP_AVX2) && defined(GEMMLOWP_X86_64) +#define GEMMLOWP_AVX2_64 +#endif + +#if defined(__has_feature) +#if __has_feature(memory_sanitizer) +#include <sanitizer/msan_interface.h> +#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __msan_unpoison +#elif __has_feature(address_sanitizer) +#include <sanitizer/asan_interface.h> +#define GEMMLOWP_MARK_MEMORY_AS_INITIALIZED __asan_unpoison_memory_region +#endif +#endif + +#endif // GEMMLOWP_ALLOW_INLINE_ASM + +// Detect Android. Don't conflate with ARM - we care about tuning +// for non-ARM Android devices too. This can be used in conjunction +// with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs. 
+#if defined(__ANDROID__) || defined(ANDROID) +#define GEMMLOWP_ANDROID +#endif + +#endif // GEMMLOWP_INTERNAL_DETECT_PLATFORM_H_ diff --git a/internal/dispatch_gemm_shape.h b/internal/dispatch_gemm_shape.h index 0be0bf3..ba4f341 100644 --- a/internal/dispatch_gemm_shape.h +++ b/internal/dispatch_gemm_shape.h @@ -85,6 +85,22 @@ struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> { } }; +template <VectorShape Shape> +struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> { + typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> SrcType; + static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value; + typedef OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape> + DstType; + static DstType Run(const SrcType& src) { + DstType dst; + dst.result_fixedpoint_multiplier = + Transpose(src.result_fixedpoint_multiplier); + dst.result_exponent = Transpose(src.result_exponent); + dst.result_offset_after_shift = src.result_offset_after_shift; + return dst; + } +}; + template <typename VectorMapType> struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> { typedef OutputStageBiasAddition<VectorMapType> SrcType; diff --git a/internal/kernel.h b/internal/kernel.h index 825a7f3..3120216 100644 --- a/internal/kernel.h +++ b/internal/kernel.h @@ -145,12 +145,24 @@ struct KernelSideFormat { static const int kCells = tCells; static const int kWidth = kCells * Cell::kWidth; static const int kDepth = Cell::kDepth; - typedef std::uint8_t Scalar; + typedef std::uint8_t Scalar; // The scalar type of the Format. + typedef std::uint8_t InputScalar; // The scalar type of the original input. }; +// KernelSideFormat for int8 fast kernel trick. The original input is uint8, but +// packs converts it to int8. template <typename tCellFormat, int tCells> struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> { typedef std::int8_t Scalar; + typedef std::uint8_t InputScalar; +}; + +// KernelSideFormat for int8 inputs, enabling int8 fast kernel trick without +// pack conversion. +template <typename tCellFormat, int tCells> +struct KernelSideFormatInt8Inputs : KernelSideFormat<tCellFormat, tCells> { + typedef std::int8_t Scalar; + typedef std::int8_t InputScalar; }; // KernelFormat describes fully the input data layout that a kernel expects. @@ -216,19 +228,24 @@ struct KernelBase { virtual ~KernelBase() {} }; -template <typename KernelScalarType> +template <typename InputKernelScalarType, typename KernelScalarType> struct ZeroPointInputValue {}; template <> -struct ZeroPointInputValue<std::uint8_t> { +struct ZeroPointInputValue<std::uint8_t, std::uint8_t> { static constexpr std::uint8_t kValue = 0; }; template <> -struct ZeroPointInputValue<std::int8_t> { +struct ZeroPointInputValue<std::uint8_t, std::int8_t> { static constexpr std::uint8_t kValue = 128; }; +template <> +struct ZeroPointInputValue<std::int8_t, std::int8_t> { + static constexpr std::uint8_t kValue = 0; +}; + } // namespace gemmlowp #endif // GEMMLOWP_INTERNAL_KERNEL_H_ diff --git a/internal/kernel_avx.h b/internal/kernel_avx.h new file mode 100644 index 0000000..2fe1249 --- /dev/null +++ b/internal/kernel_avx.h @@ -0,0 +1,361 @@ +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// kernel_SSE.h: a collection of Intel SSE optimized kernels. +// Check in kernel_default.h which one(s) are actually used by default. +// Others are mere experiments; they are still covered by tests +// in case they might be useful some day. +// + +#ifndef GEMMLOWP_INTERNAL_KERNEL_AVX_H_ +#define GEMMLOWP_INTERNAL_KERNEL_AVX_H_ + +#include "kernel.h" + +#include <string.h> +#include <cassert> + +namespace gemmlowp { + +#ifdef GEMMLOWP_AVX2_64 +struct AVX2_64_Kernel24x8Depth2 : KernelBase { + typedef KernelFormat<KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1>> + Format; + + const char *Name() const override { return "AVX, 24x8, depth 2"; } + + void Run(std::int32_t *dst_ptr, std::size_t dst_row_stride, std::size_t dst_col_stride, + const std::uint8_t *lhs_ptr, const std::uint8_t *rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { + ScopedProfilingLabel label("optimized kernel"); + assert(dst_row_stride == 1); + const std::int64_t run_depth_cells = run_depth / Format::kDepth; + const std::int64_t dst_col_stride_q = dst_col_stride; + + /* Main loop */ + + // A 2x8 cell of Rhs is stored in 16bit in ymm1 . + // A 24x2 block of 3 8x2 cells Lhs is stored in 16bit in ymm0, replaced + // every Iteration. + // A 8x8 block of accumulators is stored in 32bit in xmm4--xmm15. + // + // +-------+-------+-------+-------+ + // |ymm1[0] |ymm2[2] | + // Rhs +-------+---------------+-------+ + // |ymm1[1] |ymm1[4] | + // +-------+-------+-------+-------+ + // + // | | | | | + // + // Lhs | | | | | + // + // +--+--+ - - - - +-------+-------+-------+-------+ + // |ymm0 | | ymm4 | ymm5 | ymm6 | ymm7 | + // |ymm0 | (Iter1) | ymm4 | ymm5 | ymm6 | ymm7 | + // |ymm0 | | ymm4 | ymm5 | ymm6 | ymm7 | + // |ymm0 | | ymm4 | ymm5 | ymm6 | ymm7 | + // +--+--+ - - - - +-------+-------+-------+-------+ + // |ymm0 | | ymm8 | ymm9 | ymm10 | ymm11 | + // |ymm0 | (Iter2) | ymm8 | ymm9 | ymm10 | ymm11 | + // |ymm0 | | ymm8 | ymm9 | ymm10 | ymm11 | + // |ymm0 | | ymm8 | ymm9 | ymm10 | ymm11 | + // +--+--+ - - - - +-------+-------+-------+-------+ + // |ymm0 | | ymm12 | ymm13 | ymm14 | ymm15 | + // |ymm0 | (Iter3) | ymm12 | ymm13 | ymm14 | ymm15 | + // |ymm0 | | ymm12 | ymm13 | ymm14 | ymm15 | + // |ymm0 | | ymm12 | ymm13 | ymm14 | ymm15 | + // +--+--+ - - - - +-------+-------+-------+-------+ + // + // Accumulator + + asm volatile( + // Set registers for destination + "movq %[dst_col_stride_q], %%r12\n\t" // stride is r12 + "shlq $2, %%r12\n\t" // set stride dword + "leaq (%%r12,%%r12,0x2), %%r13\n\t" // load stride aligned r13 + + // Set accumulators to zero. 
+ "vpxor %%ymm4, %%ymm4, %%ymm4 \n\t" // zero accumulators + "vpxor %%ymm5, %%ymm5, %%ymm5 \n\t" // zero accumulators + "vpxor %%ymm6, %%ymm6, %%ymm6 \n\t" // zero accumulators + "vpxor %%ymm7, %%ymm7, %%ymm7 \n\t" // zero accumulators + "vpxor %%ymm8, %%ymm8, %%ymm8 \n\t" // zero accumulators + "vpxor %%ymm9, %%ymm9, %%ymm9 \n\t" // zero accumulators + "vpxor %%ymm10, %%ymm10, %%ymm10\n\t" // zero accumulators + "vpxor %%ymm11, %%ymm11, %%ymm11\n\t" // zero accumulators + "vpxor %%ymm12, %%ymm12, %%ymm12\n\t" // zero accumulators + "vpxor %%ymm13, %%ymm13, %%ymm13\n\t" // zero accumulators + "vpxor %%ymm14, %%ymm14, %%ymm14\n\t" // zero accumulators + "vpxor %%ymm15, %%ymm15, %%ymm15\n\t" // zero accumulators + + "movq %[run_depth_cells], %%r14 \n\t" // load cell depth r14 + "subq $2, %%r14 \n\t" // cell depth is 2 + "js outerLoop1%= \n\t" // outerloop for matrix + + // Loop for K unrolled by 4 + "outerLoop2%=: \n\t" // outer loop unroll + + // K = 0,1,2,3 + // RHS cell to ymm1 + + // lower half + "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1 + "vpermq $0x44,%%ymm1, %%ymm1 \n\t" + // LHS cell elements 0 and 1 + "vpmovzxbw 0x00(%[lhs_ptr]), %%ymm0\n\t" // mov lhs to ymm0 + "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to all ymm2 + "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to all ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs0 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs1 into ymm3 + "vpaddd %%ymm2, %%ymm4, %%ymm4 \n\t" // add muladd lhs + rhs0 into ymm4 + "vpaddd %%ymm3, %%ymm5, %%ymm5 \n\t" // add muladd lhs + rhs1 into ymm5 + // LHS cell elements 2 and 3 + "vpshufd $0xaa, %%ymm1, %%ymm2 \n\t" // move rhs 2 element to all ymm2 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rh3 into ymm2 + "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element into all ymm3 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rh4 into ymm3 + "vpaddd %%ymm2, %%ymm6, %%ymm6 \n\t" // add muladd lhs + rhs2 into ymm6 + "vpaddd %%ymm3, %%ymm7, %%ymm7 \n\t" // add muladd lhs + rhs3 into ymm7 + + // cache prefect lhs //see if it works better? 
+ //"prefetcht0 0x80(%[lhs_ptr]) \n\t" //prefetch cache lines + "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1 + "vpermq $0x44,%%ymm1, %%ymm1 \n\t" + + // K = 5,6,7,8 + // next LHS cell elements 0 and 1 + "vpmovzxbw 0x10(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0 + "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // mov rhs 0 element to all ymm2 + "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // mov rhs 1 element to all ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs0 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs1 into ymm3 + "vpaddd %%ymm2, %%ymm8, %%ymm8 \n\t" // add muladd lhs + rhs0 into ymm8 + "vpaddd %%ymm3, %%ymm9, %%ymm9 \n\t" // add muladd lhs + rhs1 into ymm9 + // next LHS cell elements 2 and 3 + "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to all ymm2 + "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to all ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs3 into ymm3 + "vpaddd %%ymm2, %%ymm10, %%ymm10 \n\t" // add muladd lhs + rhs2 into ymm10 + "vpaddd %%ymm3, %%ymm11, %%ymm11 \n\t" // add muladd lhs + rhs3 into ymm11 + + // rhs lower half + "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1 + "vpermq $0x44,%%ymm1, %%ymm1 \n\t" // duplcate lower 16 + + // next LHS cell elements 0 and 1 + "vpmovzxbw 0x20(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0 + "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // mov rhs 0 element to all ymm2 + "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // mov rhs 1 element to all ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs0 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs1 into ymm3 + "vpaddd %%ymm2, %%ymm12, %%ymm12 \n\t" // add muladd lhs + rhs0 into ymm8 + "vpaddd %%ymm3, %%ymm13, %%ymm13 \n\t" // add muladd lhs + rhs1 into ymm9 + + // cache prefetch rhs //see if it works better? 
+ //"prefetcht0 0x80(%[rhs_ptr]) \n\t" + + // next LHS cell elements 2 and 3 + "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to all ymm2 + "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to all ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mul add lhs rhs3 into ymm3 + "vpaddd %%ymm2, %%ymm14, %%ymm14 \n\t" // add muladd lhs + rhs2 into ymm10 + "vpaddd %%ymm3, %%ymm15, %%ymm15 \n\t" // add muladd lhs + rhs3 into ymm11 + + // current result in ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 + + // rhs+10 lower half + "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1 + "vpermq $0x44,%%ymm1, %%ymm1 \n\t" + // next LHS cell elements 0 and 1 + "vpmovzxbw 0x30(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0 + "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to ymm2 + "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs0 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs1 into ymm3 + "vpaddd %%ymm2, %%ymm4, %%ymm4 \n\t" // accumulate to ymm4 + "vpaddd %%ymm3, %%ymm5, %%ymm5 \n\t" // accumulate to ymm5 + // next LHS cell elements 2 and 3 + "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to ymm2 + "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to ymm2 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mull add lhs rhs3 into ymm3 + "vpaddd %%ymm2, %%ymm6, %%ymm6 \n\t" // add lhs rhs2 to ymm6 + "vpaddd %%ymm3, %%ymm7, %%ymm7 \n\t" // add lhs rhs3 to ymm7 + + // rhs+10 lower half + "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1 + "vpermq $0x44,%%ymm1, %%ymm1 \n\t" + + // next LHS cell elements 4 and 5 + "vpmovzxbw 0x40(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0 + "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to ymm2 + "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs0 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs1 into ymm3 + "vpaddd %%ymm2, %%ymm8, %%ymm8 \n\t" // accumulate to ymm8 + "vpaddd %%ymm3, %%ymm9, %%ymm9 \n\t" // accumulate to ymm9 + // next LHS cell elements 6 and 7 + "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to ymm2 + "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to ymm2 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // mul add lhs rhs2 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mull add lhs rhs3 into ymm3 + "vpaddd %%ymm2, %%ymm10, %%ymm10 \n\t" // add lhs rhs2 to ymm10 + "vpaddd %%ymm3, %%ymm11, %%ymm11 \n\t" // add lhs rhs3 to ymm11 + + "vpmovzxbw 0x08(%[rhs_ptr]), %%ymm1 \n\t" // mov rhs to ymm1 + "vpermq $0x44,%%ymm1, %%ymm1 \n\t" + // next LHS cell elements 9 and 10 + "vpmovzxbw 0x50(%[lhs_ptr]), %%ymm0 \n\t" // mov lhs to ymm0 + "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // move rhs 0 element to ymm2 + "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // move rhs 1 element to ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs0 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs1 into ymm3 + "vpaddd %%ymm2, %%ymm12, %%ymm12 \n\t" // accumulate to ymm12 + "vpaddd %%ymm3, %%ymm13, %%ymm13 \n\t" // accumulate to ymm13 + + // next LHS cell elements 11 and 12 + "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // mov rhs 2 element to ymm2 + "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // mov rhs 3 element to ymm2 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // 
mul add lhs rhs2 into ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // mull add lhs rhs3 into ymm3 + "vpaddd %%ymm2, %%ymm14, %%ymm14 \n\t" // add lhs rhs2 to ymm14 + "vpaddd %%ymm3, %%ymm15, %%ymm15 \n\t" // add lhs rhs3 to ymm15 + + // completed rhs+10 + "addq $0x60, %[lhs_ptr] \n\t" // increment stride lhs + "addq $0x10, %[rhs_ptr] \n\t" // increment stride rhs + + "subq $2, %[run_depth_cells] \n\t" + "ja outerLoop2%= \n\t" + + "movq %[run_depth_cells], %%r14 \n\t" + "decq %%r14 \n\t" + "js finish%= \n\t" + + // Loop for K unrolled by 2 + "outerLoop1%=: \n\t" + + // rhs lower + "vpmovzxbw (%[rhs_ptr]), %%ymm1 \n\t" // get rhs into ymm1 + "vpermq $0x44,%%ymm1, %%ymm1 \n\t" + + // LHS cell + "vpmovzxbw (%[lhs_ptr]), %%ymm0 \n\t" // lhs in into ymm0 + "vpshufd $0x00,%%ymm1,%%ymm2 \n\t" // rhs element 0 into ymm2 + "vpshufd $0x55,%%ymm1,%%ymm3 \n\t" // rhs element 1 into ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 0 ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 1 ymm3 + "vpaddd %%ymm2, %%ymm4, %%ymm4 \n\t" // acc element 0 ymm4 + "vpaddd %%ymm3, %%ymm5, %%ymm5 \n\t" // acc element 1 ymm5 + "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // rhs element 2 into ymm2 + "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // rhs element 3 into ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 2 ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 3 ymm3 + "vpaddd %%ymm2, %%ymm6, %%ymm6 \n\t" // acc element 2 into ymm6 + "vpaddd %%ymm3, %%ymm7, %%ymm7 \n\t" // acc element 3 into ymm7 + + // lhs+10 + "vpmovzxbw 0x10(%[lhs_ptr]), %%ymm0 \n\t" // lhs in into ymm0 + "vpshufd $0x00, %%ymm1, %%ymm2 \n\t" // rhs element 0 into ymm2 + "vpshufd $0x55, %%ymm1, %%ymm3 \n\t" // rhs element 1 into ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 0 ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 1 ymm3 + "vpaddd %%ymm2, %%ymm8, %%ymm8 \n\t" // acc element 0 ymm8 + "vpaddd %%ymm3, %%ymm9, %%ymm9 \n\t" // acc element 1 ymm9 + "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // rhs element 2 into ymm2 + "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // rhs element 3 into ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 2 ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 3 ymm3 + "vpaddd %%ymm2, %%ymm10, %%ymm10 \n\t" // acc element 2 into ymm10 + "vpaddd %%ymm3, %%ymm11, %%ymm11 \n\t" // acc element 3 into ymm11 + + "vpmovzxbw 0x20(%[lhs_ptr]), %%ymm0 \n\t" + "vpshufd $0x00, %%ymm1, %%ymm2 \n\t" // rhs element 0 into ymm2 + "vpshufd $0x55, %%ymm1, %%ymm3 \n\t" // rhs element 1 into ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 0 ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 1 ymm3 + "vpaddd %%ymm2, %%ymm12, %%ymm12 \n\t" // acc element 0 ymm12 + "vpaddd %%ymm3, %%ymm13, %%ymm13 \n\t" // acc element 1 ymm13 + "vpshufd $0xaa,%%ymm1,%%ymm2 \n\t" // rhs element 2 into ymm2 + "vpshufd $0xff,%%ymm1,%%ymm3 \n\t" // rhs element 3 into ymm3 + "vpmaddwd %%ymm0, %%ymm2, %%ymm2 \n\t" // muladd lhs rhs element 2 ymm2 + "vpmaddwd %%ymm0, %%ymm3, %%ymm3 \n\t" // muladd lhs rhs element 3 ymm3 + "vpaddd %%ymm2, %%ymm14, %%ymm14 \n\t" // acc element 2 into ymm14 + "vpaddd %%ymm3, %%ymm15, %%ymm15 \n\t" // acc element 3 into ymm15 + + // update matrix pointers + "addq $0x30, %[lhs_ptr] \n\t" + "addq $0x08, %[rhs_ptr] \n\t" + + "decq %[run_depth_cells] \n\t" + "jnz outerLoop1%= \n\t" + + "finish%=:\n\t" + + "test %[start_depth], %[start_depth] \n\t" + 
"jz storeDst%= \n\t" + + "vpaddd 0x00(%[dst_ptr]), %%ymm4, %%ymm4 \n\t" // rhs0 + "vpaddd 0x20(%[dst_ptr]), %%ymm8, %%ymm8 \n\t" // rhs0 + "vpaddd 0x40(%[dst_ptr]), %%ymm12, %%ymm12 \n\t" // rhs0 + + "vpaddd 0x00(%[dst_ptr], %%r12, 1) , %%ymm5, %%ymm5 \n\t" // rhs1 + "vpaddd 0x20(%[dst_ptr], %%r12, 1) , %%ymm9, %%ymm9 \n\t" // rhs1 + "vpaddd 0x40(%[dst_ptr], %%r12, 1) , %%ymm13, %%ymm13 \n\t" // rhs1 + + "vpaddd 0x00(%[dst_ptr], %%r12, 2) , %%ymm6, %%ymm6 \n\t" // rhs2 + "vpaddd 0x20(%[dst_ptr], %%r12, 2) , %%ymm10, %%ymm10 \n\t" // rhs2 + "vpaddd 0x40(%[dst_ptr], %%r12, 2) , %%ymm14, %%ymm14 \n\t" // rhs2 + + "vpaddd 0x00(%[dst_ptr], %%r13, 1) , %%ymm7, %%ymm7 \n\t" // rhs3 + "vpaddd 0x20(%[dst_ptr], %%r13, 1) , %%ymm11, %%ymm11 \n\t" // rhs3 + "vpaddd 0x40(%[dst_ptr], %%r13, 1) , %%ymm15, %%ymm15 \n\t" // rhs3 + + "storeDst%=:\n\t" + + "vmovdqu %%ymm4, 0x00(%[dst_ptr]) \n\t" // rhs0 + "vmovdqu %%ymm8, 0x20(%[dst_ptr]) \n\t" // rhs0 + "vmovdqu %%ymm12, 0x40(%[dst_ptr]) \n\t" // rhs0 + + "vmovdqu %%ymm5, 0x00(%[dst_ptr], %%r12, 1) \n\t" // rhs1 + "vmovdqu %%ymm9, 0x20(%[dst_ptr], %%r12, 1) \n\t" // rhs1 + "vmovdqu %%ymm13, 0x40(%[dst_ptr], %%r12, 1) \n\t" // rhs1 + + "vmovdqu %%ymm6, 0x00(%[dst_ptr], %%r12, 2) \n\t" // rhs2 + "vmovdqu %%ymm10, 0x20(%[dst_ptr], %%r12, 2) \n\t" // rhs2 + "vmovdqu %%ymm14, 0x40(%[dst_ptr], %%r12, 2) \n\t" // rhs2 + + "vmovdqu %%ymm7, 0x00(%[dst_ptr], %%r13, 1) \n\t" // rhs3 + "vmovdqu %%ymm11, 0x20(%[dst_ptr], %%r13, 1) \n\t" // rhs3 + "vmovdqu %%ymm15, 0x40(%[dst_ptr], %%r13, 1) \n\t" // rhs3 + + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_ptr] "+r"(dst_ptr) + : // inputs + [start_depth] "r"(start_depth), [dst_col_stride_q] "r"(dst_col_stride_q), + [run_depth_cells] "r"(run_depth_cells) + : // clobbers + "cc", "memory", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", + "%ymm8", "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15", "%r12", + "%r13", "%r14"); + } +}; +#endif + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_KERNEL_AVX_H_ diff --git a/internal/kernel_default.h b/internal/kernel_default.h index a919ffe..29b0991 100644 --- a/internal/kernel_default.h +++ b/internal/kernel_default.h @@ -20,66 +20,84 @@ #include "../public/bit_depth.h" #include "common.h" +#include "kernel.h" #include "kernel_reference.h" namespace gemmlowp { -template <bool MaxProductIsLessThan4096, bool LhsAlwaysNonzero> +template <bool MaxProductIsLessThan4096, bool IsUnsigned, bool LhsNonZero> struct DefaultKernelImpl {}; // Partial specialization implementing the logic that if we want to use -// a kernel for LhsAlwaysNonzero but do not have such a kernel, then we fall -// back to a generic kernel not taking advantage of LhsAlwaysNonzero. -template <bool LhsAlwaysNonzero> -struct DefaultKernelImpl<true, LhsAlwaysNonzero> - : DefaultKernelImpl<false, LhsAlwaysNonzero> {}; - -// Partial specialization implementing the logic that if we want to use // a kernel for MaxProductIsLessThan4096 but do not have such a kernel, then we // fall back to a generic kernel not taking advantage of // MaxProductIsLessThan4096. +template <bool LhsNonZero> +struct DefaultKernelImpl<true, true, LhsNonZero> + : DefaultKernelImpl<false, true, LhsNonZero> {}; + +// Partial specialization implementing the logic that if we want to use +// a kernel for LhsNonZero but do not have such a kernel, then we fall +// back to a generic kernel not taking advantage of LhsNonZero. 
template <bool MaxProductIsLessThan4096> -struct DefaultKernelImpl<MaxProductIsLessThan4096, true> - : DefaultKernelImpl<MaxProductIsLessThan4096, false> {}; +struct DefaultKernelImpl<MaxProductIsLessThan4096, true, true> + : DefaultKernelImpl<MaxProductIsLessThan4096, true, false> {}; template <typename BitDepthParams> struct DefaultKernel : DefaultKernelImpl<(BitDepthParams::LhsRange::kMaxValue * BitDepthParams::RhsRange::kMaxValue < 4096), - (BitDepthParams::LhsRange::kMinValue > 0)> {}; + (BitDepthParams::LhsRange::kMinValue >= 0), + (BitDepthParams::LhsRange::kMinValue > 0 || + (BitDepthParams::LhsRange::kMaxValue <= 127 && + BitDepthParams::LhsRange::kMinValue > -128))> {}; } // end namespace gemmlowp -#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \ - LhsAlwaysNonzero, Kernel) \ - namespace gemmlowp { \ - template <> \ - struct DefaultKernelImpl<MaxProductIsLessThan4096, LhsAlwaysNonzero> \ - : Kernel {}; \ +#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, IsUnsigned, \ + LhsAlwaysNonZero, Kernel) \ + namespace gemmlowp { \ + template <> \ + struct DefaultKernelImpl<MaxProductIsLessThan4096, IsUnsigned, \ + LhsAlwaysNonZero> : Kernel {}; \ } +// User-provided int8 inputs is only supported in the NEON path currently. #if defined GEMMLOWP_NEON_32 #include "kernel_neon.h" -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_32_Kernel12x4Depth2) -GEMMLOWP_SET_DEFAULT_KERNEL(true, false, +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_32_Kernel12x4Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(true, true, false, NEON_32_Kernel12x4Depth2Assuming12BitProducts) -GEMMLOWP_SET_DEFAULT_KERNEL(false, true, +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, NEON_32bit_GEMM_Int8Operands_LhsNonzero) +GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true, + NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs) #elif defined GEMMLOWP_NEON_64 #include "kernel_neon.h" -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2) -GEMMLOWP_SET_DEFAULT_KERNEL(false, true, +#if defined GEMMLOWP_DOTPROD_KERNEL +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, + NEON_64_Kernel12x8Depth4_dotprod) +#else +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, NEON_64_Kernel12x8Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, NEON_64bit_GEMM_Int8Operands_LhsNonzero) +#endif +GEMMLOWP_SET_DEFAULT_KERNEL(false, false, true, + NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs) #elif defined(GEMMLOWP_MSA) #include "kernel_msa.h" -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, MSA_Kernel12x8Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, MSA_Kernel12x8Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, true, MSA_GEMM_Int8Operands_LhsNonzero) #elif defined GEMMLOWP_SSE4_32 #include "kernel_sse.h" -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_32_Kernel4x4Depth2) #elif defined GEMMLOWP_SSE4_64 #include "kernel_sse.h" -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, SSE4_64_Kernel12x4Depth2) +#elif defined GEMMLOWP_AVX2_64 +#include "kernel_avx.h" +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, AVX2_64_Kernel24x8Depth2) #else #include "kernel_reference.h" namespace gemmlowp { @@ -88,7 +106,7 @@ typedef ReferenceKernel<KernelFormat< KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > > DefaultReferenceKernel; } -GEMMLOWP_SET_DEFAULT_KERNEL(false, false, DefaultReferenceKernel) +GEMMLOWP_SET_DEFAULT_KERNEL(false, true, false, DefaultReferenceKernel) 
#endif #endif // GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_ diff --git a/internal/kernel_msa.h b/internal/kernel_msa.h index 4985b73..a9205f6 100644 --- a/internal/kernel_msa.h +++ b/internal/kernel_msa.h @@ -42,8 +42,8 @@ namespace gemmlowp { // Our main GEMM kernel. struct MSA_Kernel12x8Depth2 : KernelBase { - typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>, - KernelSideFormat<CellFormat<4, 2>, 2> > + typedef KernelFormat<KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>, + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 2> > Format; const char* Name() const override { return "MSA, 12x8, depth 2"; } @@ -62,9 +62,6 @@ struct MSA_Kernel12x8Depth2 : KernelBase { assert(dst_row_stride == 1); asm volatile( - // Set a temp to all zeroes. - "ldi.b $w31, 0\n" - // Multiply dst_col_stride by 4 == sizeof(int32) to use // it as a byte offset below. GEMMLOWP_MIPS_XSLL @@ -75,32 +72,25 @@ struct MSA_Kernel12x8Depth2 : KernelBase { "beqz %[start_depth], " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n" // Load accumulators (start_depth != 0). - GEMMLOWP_MIPS_XADDU - " $a0, %[dst_ptr], %[dst_col_stride]\n" + GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n" "ld.w $w0, (0*16)(%[dst_ptr])\n" "ld.w $w4, (1*16)(%[dst_ptr])\n" - "ld.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU - " $a1, $a0, %[dst_col_stride]\n" + "ld.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n" "ld.w $w1, (0*16)($a0)\n" "ld.w $w5, (1*16)($a0)\n" - "ld.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU - " $a0, $a1, %[dst_col_stride]\n" + "ld.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n" "ld.w $w2, (0*16)($a1)\n" "ld.w $w6, (1*16)($a1)\n" - "ld.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU - " $a1, $a0, %[dst_col_stride]\n" + "ld.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n" "ld.w $w3, (0*16)($a0)\n" "ld.w $w7, (1*16)($a0)\n" - "ld.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU - " $a0, $a1, %[dst_col_stride]\n" + "ld.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n" "ld.w $w12, (0*16)($a1)\n" "ld.w $w16, (1*16)($a1)\n" - "ld.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU - " $a1, $a0, %[dst_col_stride]\n" + "ld.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n" "ld.w $w13, (0*16)($a0)\n" "ld.w $w17, (1*16)($a0)\n" - "ld.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU - " $a0, $a1, %[dst_col_stride]\n" + "ld.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n" "ld.w $w14, (0*16)($a1)\n" "ld.w $w18, (1*16)($a1)\n" "ld.w $w22, (2*16)($a1)\n" @@ -109,8 +99,7 @@ struct MSA_Kernel12x8Depth2 : KernelBase { "ld.w $w23, (2*16)($a0)\n" "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n" - GEMMLOWP_LABEL_CLEAR_ACCUMULATORS - ":\n" + GEMMLOWP_LABEL_CLEAR_ACCUMULATORS ":\n" // Clear accumulators (start_depth == 0). "ldi.w $w0, 0\n" "ldi.w $w4, 0\n" @@ -139,17 +128,16 @@ struct MSA_Kernel12x8Depth2 : KernelBase { GEMMLOWP_LABEL_BEFORE_LOOP ":\n" - GEMMLOWP_LABEL_LOOP - ":\n" + GEMMLOWP_LABEL_LOOP ":\n" // Overview of register layout: // - // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30 + // A half of the 2 2x4 cells of Rhs is stored in 16bit in w28-w31 // (each register contains 4 replicas of a pair of elements). // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26. // A 12x8 block of accumulators is stored in 32bit in w0-w23. 
// // +------+------+------+------+ - // Rhs |w27 |w28 |w29 |w30 | + // Rhs |w28 |w29 |w30 |w31 | // +------+------+------+------+ // // | | | | | @@ -179,128 +167,86 @@ struct MSA_Kernel12x8Depth2 : KernelBase { "ld.b $w24, 0(%[lhs_ptr])\n" "ld.b $w25, 8(%[lhs_ptr])\n" - // Load 4 bytes of rhs[] for the first half of depth 0. - "lbu $a0, 0(%[rhs_ptr])\n" - "lbu $a1, 1(%[rhs_ptr])\n" - "lbu $a2, 2(%[rhs_ptr])\n" - "lbu $a3, 3(%[rhs_ptr])\n" - // Load 4 bytes of rhs[] for the first half of depth 1. - "lbu $v0, 4(%[rhs_ptr])\n" - "lbu $v1, 5(%[rhs_ptr])\n" - "lbu $t8, 6(%[rhs_ptr])\n" - "lbu $t9, 7(%[rhs_ptr])\n" + // Load 2 x 8 bytes of rhs[]. + "ld.b $w27, 0(%[rhs_ptr])\n" // Zero-extend 8-bit elements of lhs[] to 16 bits. + "ldi.b $w31, 0\n" "ilvr.b $w24, $w31, $w24\n" "ilvl.b $w26, $w31, $w25\n" "ilvr.b $w25, $w31, $w25\n" - // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w. - "ilvl.d $w27, $w31, $w24\n" - "ilvl.d $w28, $w31, $w25\n" - "ilvl.d $w29, $w31, $w26\n" - "ilvr.h $w24, $w27, $w24\n" - "ilvr.h $w25, $w28, $w25\n" - "ilvr.h $w26, $w29, $w26\n" - - // Combine and interleave depth 0 and depth 1 elements of rhs[] for - // dpadd_u.w (for the first half). - "ins $a0, $v0, 16, 8\n" - "ins $a1, $v1, 16, 8\n" - "ins $a2, $t8, 16, 8\n" - "ins $a3, $t9, 16, 8\n" - // Make 4 replicas of every pair of rhs[] elements. - "fill.w $w27, $a0\n" - "fill.w $w28, $a1\n" - "fill.w $w29, $a2\n" - "fill.w $w30, $a3\n" - - // Load 4 bytes of rhs[] for the second half of depth 0. - "lbu $a0, 8(%[rhs_ptr])\n" - "lbu $a1, 9(%[rhs_ptr])\n" - "lbu $a2, 10(%[rhs_ptr])\n" - "lbu $a3, 11(%[rhs_ptr])\n" - // Load 4 bytes of rhs[] for the second half of depth 1. - "lbu $v0, 12(%[rhs_ptr])\n" - "lbu $v1, 13(%[rhs_ptr])\n" - "lbu $t8, 14(%[rhs_ptr])\n" - "lbu $t9, 15(%[rhs_ptr])\n" // First half of depths 0 and 1. - // Dot-product-(and)-add doubles multiplicand width. - "dpadd_u.w $w0, $w24, $w27\n" - "dpadd_u.w $w4, $w25, $w27\n" - "dpadd_u.w $w8, $w26, $w27\n" - "dpadd_u.w $w1, $w24, $w28\n" - "dpadd_u.w $w5, $w25, $w28\n" - "dpadd_u.w $w9, $w26, $w28\n" - "dpadd_u.w $w2, $w24, $w29\n" - "dpadd_u.w $w6, $w25, $w29\n" - "dpadd_u.w $w10, $w26, $w29\n" - "dpadd_u.w $w3, $w24, $w30\n" - "dpadd_u.w $w7, $w25, $w30\n" - "dpadd_u.w $w11, $w26, $w30\n" - - // Combine and interleave depth 0 and depth 1 elements of rhs[] for - // dpadd_u.w (for the second half). - "ins $a0, $v0, 16, 8\n" - "ins $a1, $v1, 16, 8\n" - "ins $a2, $t8, 16, 8\n" - "ins $a3, $t9, 16, 8\n" + // Zero-extend 8-bit elements of rhs[] to 16 bits. + "ilvr.b $w31, $w31, $w27\n" // Make 4 replicas of every pair of rhs[] elements. - "fill.w $w27, $a0\n" - "fill.w $w28, $a1\n" - "fill.w $w29, $a2\n" - "fill.w $w30, $a3\n" + "splati.w $w28, $w31[0]\n" + "splati.w $w29, $w31[1]\n" + "splati.w $w30, $w31[2]\n" + "splati.w $w31, $w31[3]\n" + // Dot-product-(and)-add doubles multiplicand width. + "dpadd_u.w $w0, $w24, $w28\n" + "dpadd_u.w $w4, $w25, $w28\n" + "dpadd_u.w $w8, $w26, $w28\n" + "dpadd_u.w $w1, $w24, $w29\n" + "dpadd_u.w $w5, $w25, $w29\n" + "dpadd_u.w $w9, $w26, $w29\n" + "dpadd_u.w $w2, $w24, $w30\n" + "dpadd_u.w $w6, $w25, $w30\n" + "dpadd_u.w $w10, $w26, $w30\n" + "dpadd_u.w $w3, $w24, $w31\n" + "dpadd_u.w $w7, $w25, $w31\n" + "dpadd_u.w $w11, $w26, $w31\n" // Second half of depths 0 and 1. + // Zero-extend 8-bit elements of rhs[] to 16 bits. + "ldi.b $w31, 0\n" + "ilvl.b $w31, $w31, $w27\n" + // Make 4 replicas of every pair of rhs[] elements. 
+ "splati.w $w28, $w31[0]\n" + "splati.w $w29, $w31[1]\n" + "splati.w $w30, $w31[2]\n" + "splati.w $w31, $w31[3]\n" // Dot-product-(and)-add doubles multiplicand width. - "dpadd_u.w $w12, $w24, $w27\n" - "dpadd_u.w $w16, $w25, $w27\n" - "dpadd_u.w $w20, $w26, $w27\n" - "dpadd_u.w $w13, $w24, $w28\n" - "dpadd_u.w $w17, $w25, $w28\n" - "dpadd_u.w $w21, $w26, $w28\n" - "dpadd_u.w $w14, $w24, $w29\n" - "dpadd_u.w $w18, $w25, $w29\n" - "dpadd_u.w $w22, $w26, $w29\n" - "dpadd_u.w $w15, $w24, $w30\n" - "dpadd_u.w $w19, $w25, $w30\n" - "dpadd_u.w $w23, $w26, $w30\n" + "dpadd_u.w $w12, $w24, $w28\n" + "dpadd_u.w $w16, $w25, $w28\n" + "dpadd_u.w $w20, $w26, $w28\n" + "dpadd_u.w $w13, $w24, $w29\n" + "dpadd_u.w $w17, $w25, $w29\n" + "dpadd_u.w $w21, $w26, $w29\n" + "dpadd_u.w $w14, $w24, $w30\n" + "dpadd_u.w $w18, $w25, $w30\n" + "dpadd_u.w $w22, $w26, $w30\n" + "dpadd_u.w $w15, $w24, $w31\n" + "dpadd_u.w $w19, $w25, $w31\n" + "dpadd_u.w $w23, $w26, $w31\n" GEMMLOWP_MIPS_XADDIU " %[run_depth], -2\n" GEMMLOWP_MIPS_XADDIU - " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU - " %[rhs_ptr], 16\n" + " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n" "bnez %[run_depth]," GEMMLOWP_LABEL_LOOP "b\n" GEMMLOWP_LABEL_AFTER_LOOP ":\n" // Store accumulators. - GEMMLOWP_MIPS_XADDU - " $a0, %[dst_ptr], %[dst_col_stride]\n" + GEMMLOWP_MIPS_XADDU " $a0, %[dst_ptr], %[dst_col_stride]\n" "st.w $w0, (0*16)(%[dst_ptr])\n" "st.w $w4, (1*16)(%[dst_ptr])\n" - "st.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU - " $a1, $a0, %[dst_col_stride]\n" + "st.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n" "st.w $w1, (0*16)($a0)\n" "st.w $w5, (1*16)($a0)\n" - "st.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU - " $a0, $a1, %[dst_col_stride]\n" + "st.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n" "st.w $w2, (0*16)($a1)\n" "st.w $w6, (1*16)($a1)\n" - "st.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU - " $a1, $a0, %[dst_col_stride]\n" + "st.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n" "st.w $w3, (0*16)($a0)\n" "st.w $w7, (1*16)($a0)\n" - "st.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU - " $a0, $a1, %[dst_col_stride]\n" + "st.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n" "st.w $w12, (0*16)($a1)\n" "st.w $w16, (1*16)($a1)\n" - "st.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU - " $a1, $a0, %[dst_col_stride]\n" + "st.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU " $a1, $a0, %[dst_col_stride]\n" "st.w $w13, (0*16)($a0)\n" "st.w $w17, (1*16)($a0)\n" - "st.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU - " $a0, $a1, %[dst_col_stride]\n" + "st.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU " $a0, $a1, %[dst_col_stride]\n" "st.w $w14, (0*16)($a1)\n" "st.w $w18, (1*16)($a1)\n" "st.w $w22, (2*16)($a1)\n" @@ -308,18 +254,15 @@ struct MSA_Kernel12x8Depth2 : KernelBase { "st.w $w19, (1*16)($a0)\n" "st.w $w23, (2*16)($a0)\n" : // outputs - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [run_depth] "+r"(run_depth), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [run_depth] "+r"(run_depth), [dst_col_stride] "+r"(dst_col_stride) : // inputs [dst_ptr] "r"(dst_ptr), [start_depth] "r"(start_depth) : // clobbers - "memory", "v0", "v1", "a0", "a1", "a2", "a3", "t8", "t9", "$f0", "$f1", - "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", "$f10", "$f11", - "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20", - "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", - "$f30", "$f31"); + "memory", "a0", "a1", 
"$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", + "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20", + "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31"); #undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS #undef GEMMLOWP_LABEL_BEFORE_LOOP @@ -328,6 +271,303 @@ struct MSA_Kernel12x8Depth2 : KernelBase { } }; +// Fast kernel operating on int8 operands. +// It is assumed that one of the two int8 operands only takes values +// in [-127, 127], while the other may freely range in [-128, 127]. +// The issue with both operands taking the value -128 is that: +// -128*-128 + -128*-128 == -32768 overflows int16. +// Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 +// range. That is the basic idea of this kernel. +struct MSA_GEMM_Int8Operands_LhsNonzero : KernelBase { + typedef KernelFormat< + KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, + KernelSideFormatInt8<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > + Format; + + const char* Name() const override { + return "MSA, 4x4, depth 16, accumulating two within signed int16"; + } + + // TODO(benoitjacob): reorder function arguments so dst comes last + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, + std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, + const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t run_depth) const override { + (void)dst_row_stride; +#define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1" +#define GEMMLOWP_LABEL_LOOP "2" +#define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3" +#define GEMMLOWP_LABEL_STORE "4" + asm volatile( + GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n" + // Load lhs[] and rhs[], zero out internal accumulators. + "ld.b $w16, 0(%[lhs_ptr])\n" + "ldi.b $w0, 0\n" + "ld.b $w20, 0(%[rhs_ptr])\n" + "ldi.b $w1, 0\n" + "ld.b $w17, 16(%[lhs_ptr])\n" + "ldi.b $w2, 0\n" + "ld.b $w21, 16(%[rhs_ptr])\n" + "ldi.b $w3, 0\n" + "ld.b $w18, 32(%[lhs_ptr])\n" + "ldi.b $w4, 0\n" + "ld.b $w19, 48(%[lhs_ptr])\n" + "ldi.b $w5, 0\n" + "ld.b $w22, 32(%[rhs_ptr])\n" + "ldi.b $w6, 0\n" + "ld.b $w23, 48(%[rhs_ptr])\n" + "ldi.b $w7, 0\n" + "ldi.b $w8, 0\n" + "ldi.b $w9, 0\n" + "ldi.b $w10, 0\n" + "ldi.b $w11, 0\n" + "ldi.b $w12, 0\n" + "ldi.b $w13, 0\n" + "ldi.b $w14, 0\n" + "ldi.b $w15, 0\n" + "ldi.h $w31, 1\n" + // If the loop depth is only 16, then we can skip the general loop + // and go straight to the final part of the code. + "beqz %[run_depth], " GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "f\n" + + GEMMLOWP_LABEL_LOOP ":\n" + // Overview of register layout: + // + // A 4x16 block of Rhs is stored in 8 bit in w16-w19. + // A 4x16 block of Lhs is stored in 8 bit in w20-w23. + // + // A 4x4 block of accumulators is stored in w0-w15 (as 4x32 bit + // components which need to be horizontally added at the end). + // + // Dot products of Lhs and Rhs are 16-bit values, which can't + // immediately be accumulated in 32-bit accumulators by that + // same instruction that calculates them. + // For example, "dotp_s.h $w25, $w16, $w20" produces 8 16-bit + // sums in w25 (note, the 16 sums have already been reduced to 8 + // by the horizontal addition of the dotp instruction). + // They are then sign-extended to 32 bits, horizontally added + // (again) to form 4 32-bit sums and then they are finally added + // to the 32-bit accumulators, all by "dpadd_s.w $w0, $w25, $w31". 
+ // + // +-----+-----+-----+-----+ + // Rhs | w20 | w21 | w22 | w23 | + // +-----+-----+-----+-----+ + // + // | | | | | + // + // Lhs | | | | | + // + // +---+ - - - - +-----+-----+-----+-----+ + // |w16| | w0 | w4 | w8 | w12 | + // |w17| | w1 | w5 | w9 | w13 | + // |w18| | w2 | w6 | w10 | w14 | + // |w19| | w3 | w7 | w11 | w15 | + // +---+ - - - - +-----+-----+-----+-----+ + // + // Accumulators + + // Calculate the results for 16 depths and load + // lhs[] and rhs[] for the next iteration. + GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 64\n" + GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 64\n" + GEMMLOWP_MIPS_XADDIU " %[run_depth], -16\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w16, $w20\n" + "dotp_s.h $w26, $w17, $w20\n" + "dotp_s.h $w27, $w16, $w21\n" + "dotp_s.h $w28, $w17, $w21\n" + "dotp_s.h $w29, $w18, $w20\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w0, $w25, $w31\n" + "dpadd_s.w $w1, $w26, $w31\n" + "dpadd_s.w $w4, $w27, $w31\n" + "dpadd_s.w $w5, $w28, $w31\n" + "dpadd_s.w $w2, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w24, $w16, $w22\n" + "dotp_s.h $w25, $w19, $w20\n" + "dotp_s.h $w26, $w16, $w23\n" + "dotp_s.h $w27, $w17, $w22\n" + "ld.b $w20, 0(%[rhs_ptr])\n" + "dotp_s.h $w28, $w17, $w23\n" + "ld.b $w16, 0(%[lhs_ptr])\n" + "dotp_s.h $w29, $w18, $w21\n" + "ld.b $w17, 16(%[lhs_ptr])\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w8, $w24, $w31\n" + "dpadd_s.w $w3, $w25, $w31\n" + "dpadd_s.w $w12, $w26, $w31\n" + "dpadd_s.w $w9, $w27, $w31\n" + "dpadd_s.w $w13, $w28, $w31\n" + "dpadd_s.w $w6, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w19, $w21\n" + "dotp_s.h $w26, $w18, $w22\n" + "dotp_s.h $w27, $w18, $w23\n" + "ld.b $w21, 16(%[rhs_ptr])\n" + "dotp_s.h $w28, $w19, $w22\n" + "ld.b $w18, 32(%[lhs_ptr])\n" + "dotp_s.h $w29, $w19, $w23\n" + "ld.b $w22, 32(%[rhs_ptr])\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w7, $w25, $w31\n" + "ld.b $w19, 48(%[lhs_ptr])\n" + "dpadd_s.w $w10, $w26, $w31\n" + "ld.b $w23, 48(%[rhs_ptr])\n" + "dpadd_s.w $w14, $w27, $w31\n" + "dpadd_s.w $w11, $w28, $w31\n" + "dpadd_s.w $w15, $w29, $w31\n" + + "bnez %[run_depth], " GEMMLOWP_LABEL_LOOP "b\n" + + GEMMLOWP_LABEL_AFTER_LOOP_LAST16 ":\n" + // Calculate the results for the last 16 depths. + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w16, $w20\n" + "dotp_s.h $w26, $w17, $w20\n" + "dotp_s.h $w27, $w16, $w21\n" + "dotp_s.h $w28, $w17, $w21\n" + "dotp_s.h $w29, $w18, $w20\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w0, $w25, $w31\n" + "dpadd_s.w $w1, $w26, $w31\n" + "dpadd_s.w $w4, $w27, $w31\n" + "dpadd_s.w $w5, $w28, $w31\n" + "dpadd_s.w $w2, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. 
+ "dotp_s.h $w24, $w16, $w22\n" + "dotp_s.h $w25, $w19, $w20\n" + "dotp_s.h $w26, $w16, $w23\n" + "dotp_s.h $w27, $w17, $w22\n" + "dotp_s.h $w28, $w17, $w23\n" + "dotp_s.h $w29, $w18, $w21\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w8, $w24, $w31\n" + "dpadd_s.w $w3, $w25, $w31\n" + "dpadd_s.w $w12, $w26, $w31\n" + "dpadd_s.w $w9, $w27, $w31\n" + "dpadd_s.w $w13, $w28, $w31\n" + "dpadd_s.w $w6, $w29, $w31\n" + + // Dot product: multiply-add pairs of adjacent int8 elements. + // Each dot product takes 16*2 int8 values in and produces 8 int16 sums. + "dotp_s.h $w25, $w19, $w21\n" + "dotp_s.h $w26, $w18, $w22\n" + "dotp_s.h $w27, $w18, $w23\n" + "dotp_s.h $w28, $w19, $w22\n" + "dotp_s.h $w29, $w19, $w23\n" + // Horizontal add of pairs of adjacent int16 sums into internal int32 + // accumulators. + "dpadd_s.w $w7, $w25, $w31\n" + "dpadd_s.w $w10, $w26, $w31\n" + "dpadd_s.w $w14, $w27, $w31\n" + "dpadd_s.w $w11, $w28, $w31\n" + "dpadd_s.w $w15, $w29, $w31\n" + + // Horizontal-add internal accumulators. + "hadd_s.d $w0, $w0, $w0\n" + "hadd_s.d $w1, $w1, $w1\n" + "hadd_s.d $w2, $w2, $w2\n" + "hadd_s.d $w3, $w3, $w3\n" + "hadd_s.d $w4, $w4, $w4\n" + "hadd_s.d $w5, $w5, $w5\n" + "hadd_s.d $w6, $w6, $w6\n" + "hadd_s.d $w7, $w7, $w7\n" + "hadd_s.d $w8, $w8, $w8\n" + "hadd_s.d $w9, $w9, $w9\n" + "hadd_s.d $w10, $w10, $w10\n" + "hadd_s.d $w11, $w11, $w11\n" + "hadd_s.d $w12, $w12, $w12\n" + "hadd_s.d $w13, $w13, $w13\n" + "hadd_s.d $w14, $w14, $w14\n" + "hadd_s.d $w15, $w15, $w15\n" + "pckev.w $w0, $w1, $w0\n" + "pckev.w $w2, $w3, $w2\n" + "pckev.w $w4, $w5, $w4\n" + "pckev.w $w6, $w7, $w6\n" + "pckev.w $w8, $w9, $w8\n" + "pckev.w $w10, $w11, $w10\n" + "pckev.w $w12, $w13, $w12\n" + "pckev.w $w14, $w15, $w14\n" + "hadd_s.d $w0, $w0, $w0\n" + "hadd_s.d $w2, $w2, $w2\n" + "hadd_s.d $w4, $w4, $w4\n" + "hadd_s.d $w6, $w6, $w6\n" + "hadd_s.d $w8, $w8, $w8\n" + "hadd_s.d $w10, $w10, $w10\n" + "hadd_s.d $w12, $w12, $w12\n" + "hadd_s.d $w14, $w14, $w14\n" + // 4 more pckev instructions follow in both paths below. + + // Check if start_depth==0 to decide whether we will load + // existing accumulators from memory. + "bnez %[start_depth], " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "f\n" + + "pckev.w $w0, $w2, $w0\n" + "pckev.w $w1, $w6, $w4\n" + "pckev.w $w2, $w10, $w8\n" + "pckev.w $w3, $w14, $w12\n" + + "b " GEMMLOWP_LABEL_STORE "f\n" + + GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES ":\n" + // Load accumulators from memory. + "ld.w $w16, 0(%[dst_ptr0])\n" + "pckev.w $w0, $w2, $w0\n" + "ld.w $w17, 0(%[dst_ptr1])\n" + "pckev.w $w1, $w6, $w4\n" + "ld.w $w18, 0(%[dst_ptr2])\n" + "pckev.w $w2, $w10, $w8\n" + "ld.w $w19, 0(%[dst_ptr3])\n" + "pckev.w $w3, $w14, $w12\n" + + // Add them to internal accumulators. + "addv.w $w0, $w0, $w16\n" + "addv.w $w1, $w1, $w17\n" + "addv.w $w2, $w2, $w18\n" + "addv.w $w3, $w3, $w19\n" + + GEMMLOWP_LABEL_STORE ":\n" + // Store accumulators. 
+ "st.w $w0, 0(%[dst_ptr0])\n" + "st.w $w1, 0(%[dst_ptr1])\n" + "st.w $w2, 0(%[dst_ptr2])\n" + "st.w $w3, 0(%[dst_ptr3])\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [run_depth] "+r"(run_depth) + : // inputs + [dst_ptr0] "r"(dst_ptr), [dst_ptr1] "r"(dst_ptr + dst_col_stride), + [dst_ptr2] "r"(dst_ptr + dst_col_stride * 2), + [dst_ptr3] "r"(dst_ptr + dst_col_stride * 3), + [start_depth] "r"(start_depth) + : // clobbers + "memory", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", + "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", + "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", + "$f27", "$f28", "$f29", "$f30", "$f31"); +#undef GEMMLOWP_LABEL_LOOP +#undef GEMMLOWP_LABEL_AFTER_LOOP_LAST16 +#undef GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES +#undef GEMMLOWP_LABEL_STORE + } +}; + #undef GEMMLOWP_MIPS_XADDU #undef GEMMLOWP_MIPS_XADDIU #undef GEMMLOWP_MIPS_XSLL diff --git a/internal/kernel_neon.h b/internal/kernel_neon.h index 3cd48f4..9859637 100644 --- a/internal/kernel_neon.h +++ b/internal/kernel_neon.h @@ -55,6 +55,7 @@ struct NEON_32_Kernel12x4Depth2 : KernelBase { #define GEMMLOWP_LABEL_AFTER_LOOP "4" assert(dst_row_stride == 1); + (void)dst_row_stride; asm volatile( // Overview of register layout: // @@ -308,6 +309,7 @@ struct NEON_32_Kernel12x4Depth2Assuming12BitProducts : KernelBase { ScopedProfilingLabel label( "optimized kernel (NEON 12x4, assuming 12-bit products)"); assert(dst_row_stride == 1); + (void)dst_row_stride; // See comments above for why we need local numerical labels in our asm. #define GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS "1" @@ -678,6 +680,7 @@ struct NEON_32bit_GEMM_Int8Operands_LhsNonzero : KernelBase { std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, const std::uint8_t* rhs_ptr, std::size_t start_depth, std::size_t run_depth) const override { + (void)dst_row_stride; #define GEMMLOWP_LABEL_AFTER_LOOP "1" #define GEMMLOWP_LABEL_LOOP "2" #define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3" @@ -921,6 +924,17 @@ struct NEON_32bit_GEMM_Int8Operands_LhsNonzero : KernelBase { } }; +// Same as NEON_32bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that +// requires that user inputs were originally int8. This avoids the uint8->int8 +// conversion in the pack step. +struct NEON_32bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs + : NEON_32bit_GEMM_Int8Operands_LhsNonzero { + typedef KernelFormat< + KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, + KernelSideFormatInt8Inputs<CellFormat<2, 16, CellOrder::WidthMajor>, 1> > + Format; +}; + #endif // GEMMLOWP_NEON_32 // The kernels here are specifically arm 64bit assembly, not arm 32bit. @@ -940,6 +954,7 @@ struct NEON_64bit_GEMM_Int8Operands_LhsNonzero : KernelBase { std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, const std::uint8_t* rhs_ptr, std::size_t start_depth, std::size_t run_depth) const override { + (void)dst_row_stride; #define GEMMLOWP_LABEL_AFTER_LOOP_LAST16 "1" #define GEMMLOWP_LABEL_LOOP "2" #define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3" @@ -1261,6 +1276,17 @@ struct NEON_64bit_GEMM_Int8Operands_LhsNonzero : KernelBase { } }; +// Same as NEON_32bit_GEMM_Int8Operands_LhsNonzero, but uses a side format that +// requires that user inputs were originally int8. This avoids the uint8->int8 +// conversion in the pack step. 
+struct NEON_64bit_GEMM_Int8Operands_LhsNonzero_Int8Inputs + : NEON_64bit_GEMM_Int8Operands_LhsNonzero { + typedef KernelFormat< + KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, + KernelSideFormatInt8Inputs<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > + Format; +}; + // Our main GEMM kernel. struct NEON_64_Kernel12x8Depth2 : KernelBase { typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>, @@ -1274,6 +1300,7 @@ struct NEON_64_Kernel12x8Depth2 : KernelBase { std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, const std::uint8_t* rhs_ptr, std::size_t start_depth, std::size_t run_depth) const override { + (void)dst_row_stride; ScopedProfilingLabel label("optimized kernel (NEON 12x8)"); // See comments above for why we need local numerical labels in our asm. #define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1" @@ -1611,6 +1638,274 @@ struct NEON_64_Kernel12x8Depth2 : KernelBase { } }; +#ifdef GEMMLOWP_DOTPROD_KERNEL +#ifndef __ARM_FEATURE_DOTPROD +#error This kernel requires ARM dot-product instructions. Enable them by \ + adding '+dotprod' to a compiler flag, e.g. -march=armv8.2-a+dotprod . \ + Note that Clang up to version 7 fails to define the corresponding \ + preprocessor token __ARM_FEATURE_DOTPROD, so you will still have to define \ + it manually. +#endif +// Kernels utilizing the Armv8.2 Dot Product extension. +// +// The dot product instructions work by taking 4 consecutive 8-bit depth +// values from each operand, multiplying the 4 pairs together and +// accumulating all the results into the corresponding 32-bit accumulator +// lane. As such, the operation is identical to a 32-bit instruction (like +// FMLA used in SGEMM), except that 4 depth values are processed at a time +// instead of 1. + +// Thus, this first kernel is a carbon copy of +// "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good +// performance for most processors) below with the opcode (fmla -> udot) and +// types (float32 -> uint8/uint32) changed. +// +// A signed version of this kernel could be produced by replacing "udot" +// with "sdot" - performance should be identical to this udot kernel. +struct NEON_64_Kernel12x8Depth4_dotprod : KernelBase { + typedef KernelFormat<KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>, + KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> > + Format; + + const char* Name() const override { return "NEON, 12x8, depth 4, dotprod"; } + + void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, std::size_t dst_col_stride, + const std::uint8_t* lhs_ptr, const std::uint8_t* rhs_ptr, std::size_t start_depth, + std::size_t depth) const override { + (void)dst_row_stride; + ScopedProfilingLabel label("optimized kernel (NEON 12x8, depth 4, dotprod)"); +// See comments above for why we need local numerical labels in our asm. +#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1" +#define GEMMLOWP_LABEL_BEFORE_LOOP "2" +#define GEMMLOWP_LABEL_LOOP "3" +#define GEMMLOWP_LABEL_AFTER_LOOP "4" + + assert(dst_row_stride == 1); + asm volatile( + // Multiply dst_col_stride by 4 == sizeof(int32) to use + // it as a byte offset below. 
+ "lsl %[dst_col_stride], %[dst_col_stride], #2\n" + + "cmp %[start_depth], #0\n" + "beq " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n" + + // Load accumulators + "mov x1, %[dst_ptr]\n" + "mov x0, x1\n" + "ld1 {v8.16b}, [x0], #16\n" + "ld1 {v16.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v24.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v9.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v17.16b}, [x0], #16\n" + "ld1 {v25.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v10.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v18.16b}, [x0], #16\n" + "ld1 {v26.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v11.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v19.16b}, [x0], #16\n" + "ld1 {v27.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v12.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v20.16b}, [x0], #16\n" + "ld1 {v28.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v13.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v21.16b}, [x0], #16\n" + "ld1 {v29.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v14.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "ld1 {v22.16b}, [x0], #16\n" + "ld1 {v30.16b}, [x0]\n" + "mov x0, x1\n" + "ld1 {v15.16b}, [x0], #16\n" + "ld1 {v23.16b}, [x0], #16\n" + "ld1 {v31.16b}, [x0]\n" + + "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n" + + GEMMLOWP_LABEL_CLEAR_ACCUMULATORS ":\n" + + // Clear accumulator registers (see layout below) + "dup v8.4s, wzr\n" + "dup v9.4s, wzr\n" + "dup v10.4s, wzr\n" + "dup v11.4s, wzr\n" + "dup v12.4s, wzr\n" + "dup v13.4s, wzr\n" + "dup v14.4s, wzr\n" + "dup v15.4s, wzr\n" + "dup v16.4s, wzr\n" + "dup v17.4s, wzr\n" + "dup v18.4s, wzr\n" + "dup v19.4s, wzr\n" + "dup v20.4s, wzr\n" + "dup v21.4s, wzr\n" + "dup v22.4s, wzr\n" + "dup v23.4s, wzr\n" + "dup v24.4s, wzr\n" + "dup v25.4s, wzr\n" + "dup v26.4s, wzr\n" + "dup v27.4s, wzr\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + GEMMLOWP_LABEL_BEFORE_LOOP ":\n" + + "subs %w[depth], %w[depth], #4\n" + + // The start of the loop assumes first Rhs cell is already loaded, so + // do it here for first iteration. + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" + + // And the same for the first Lhs cell. + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + + "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n" + + GEMMLOWP_LABEL_LOOP ":\n" + + // Start the MACs at the head of the loop - 1st cell from each side + // already loaded. + ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n" + ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell. + ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n" + ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell. + ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n" + ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell. + ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n" + ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load + // for the next iteration early. 
+ ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n" + ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n" + ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n" + ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n" + ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n" + ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n" + ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n" + ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n" + ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n" + ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n" + ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n" + ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n" + "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell - + // load for the next iteration early. + ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n" + ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n" + + // Loop. Decrement loop index (depth) by 4 as udot processes 4 + // depth values. + "subs %w[depth], %w[depth], #4\n" + ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n" + ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n" + + "bne " GEMMLOWP_LABEL_LOOP "b\n" + + GEMMLOWP_LABEL_AFTER_LOOP ":\n" + + // Final iteration. v0 and v2 were already loaded, don't load + // them again, don't read past the end of buffers. + ".word 0x6f80e048 // udot v8.4s, v2.16b, v0.4b[0]\n" + ".word 0x6fa0e049 // udot v9.4s, v2.16b, v0.4b[1]\n" + "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell. + ".word 0x6f80e84a // udot v10.4s, v2.16b, v0.4b[2]\n" + ".word 0x6fa0e84b // udot v11.4s, v2.16b, v0.4b[3]\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell. + ".word 0x6f81e04c // udot v12.4s, v2.16b, v1.4b[0]\n" + ".word 0x6fa1e04d // udot v13.4s, v2.16b, v1.4b[1]\n" + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell. + ".word 0x6f81e84e // udot v14.4s, v2.16b, v1.4b[2]\n" + ".word 0x6fa1e84f // udot v15.4s, v2.16b, v1.4b[3]\n" + ".word 0x6f80e070 // udot v16.4s, v3.16b, v0.4b[0]\n" + ".word 0x6fa0e071 // udot v17.4s, v3.16b, v0.4b[1]\n" + ".word 0x6f80e872 // udot v18.4s, v3.16b, v0.4b[2]\n" + ".word 0x6fa0e873 // udot v19.4s, v3.16b, v0.4b[3]\n" + ".word 0x6f81e074 // udot v20.4s, v3.16b, v1.4b[0]\n" + ".word 0x6fa1e075 // udot v21.4s, v3.16b, v1.4b[1]\n" + ".word 0x6f81e876 // udot v22.4s, v3.16b, v1.4b[2]\n" + ".word 0x6fa1e877 // udot v23.4s, v3.16b, v1.4b[3]\n" + ".word 0x6f80e098 // udot v24.4s, v4.16b, v0.4b[0]\n" + ".word 0x6fa0e099 // udot v25.4s, v4.16b, v0.4b[1]\n" + ".word 0x6f80e89a // udot v26.4s, v4.16b, v0.4b[2]\n" + ".word 0x6fa0e89b // udot v27.4s, v4.16b, v0.4b[3]\n" + ".word 0x6f81e09c // udot v28.4s, v4.16b, v1.4b[0]\n" + ".word 0x6fa1e09d // udot v29.4s, v4.16b, v1.4b[1]\n" + + // Loop. Decrement loop index (depth) by 4 as udot processes 4 + // depth values. 
+ "subs %w[depth], %w[depth], #4\n" + ".word 0x6f81e89e // udot v30.4s, v4.16b, v1.4b[2]\n" + ".word 0x6fa1e89f // udot v31.4s, v4.16b, v1.4b[3]\n" + + // Store accumulators + "mov x1, %[dst_ptr]\n" + "mov x0, x1\n" + "st1 {v8.16b}, [x0], #16\n" + "st1 {v16.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v24.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v9.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v17.16b}, [x0], #16\n" + "st1 {v25.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v10.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v18.16b}, [x0], #16\n" + "st1 {v26.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v11.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v19.16b}, [x0], #16\n" + "st1 {v27.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v12.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v20.16b}, [x0], #16\n" + "st1 {v28.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v13.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v21.16b}, [x0], #16\n" + "st1 {v29.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v14.16b}, [x0], #16\n" + "add x1, x1, %[dst_col_stride]\n" + "st1 {v22.16b}, [x0], #16\n" + "st1 {v30.16b}, [x0]\n" + "mov x0, x1\n" + "st1 {v15.16b}, [x0], #16\n" + "st1 {v23.16b}, [x0], #16\n" + "st1 {v31.16b}, [x0]\n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [depth] "+r"(depth) + : // inputs + [dst_ptr] "r"(dst_ptr), [dst_col_stride] "r"(dst_col_stride), [start_depth] "r"(start_depth) + : // clobbers + "cc", "memory", "x0", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } +}; +#endif // GEMMLOWP_DOTPROD_KERNEL + #endif // GEMMLOWP_NEON_64 } // namespace gemmlowp diff --git a/internal/kernel_sse.h b/internal/kernel_sse.h index b879fd7..ba7959b 100644 --- a/internal/kernel_sse.h +++ b/internal/kernel_sse.h @@ -43,6 +43,7 @@ struct SSE4_32_Kernel4x4Depth2 : KernelBase { std::size_t run_depth) const override { ScopedProfilingLabel label("optimized kernel"); assert(dst_row_stride == 1); + (void)dst_row_stride; std::int32_t run_depth_cells = run_depth / Format::kDepth; /* Main loop */ @@ -217,6 +218,7 @@ struct SSE4_64_Kernel12x4Depth2 : KernelBase { std::size_t run_depth) const override { ScopedProfilingLabel label("optimized kernel"); assert(dst_row_stride == 1); + (void)dst_row_stride; const std::int64_t run_depth_cells = run_depth / Format::kDepth; const std::int64_t dst_col_stride_q = dst_col_stride; diff --git a/internal/multi_thread_gemm.h b/internal/multi_thread_gemm.h index 791402f..97183e7 100644 --- a/internal/multi_thread_gemm.h +++ b/internal/multi_thread_gemm.h @@ -19,23 +19,43 @@ #ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_ #define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_ +#include <atomic> // NOLINT +#include <chrono> // NOLINT +#include <thread> // NOLINT #include <vector> #include "single_thread_gemm.h" namespace gemmlowp { -// On X86 and ARM platforms we enable a busy-wait spinlock before waiting on a -// pthread conditional variable. In order to implement that correctly we need -// to put some explicit memory load/store barriers. +// This value was empirically derived on an end-to-end application benchmark. +// That this number of cycles means that we may be sleeping substantially longer +// than a scheduler timeslice's duration is not necessarily surprising. 
The +// idea is to pick up quickly new work after having finished the previous +// workload. When it's new work within the same GEMM as the previous work, the +// time interval that we might be busy-waiting is very small, so for that +// purpose it would be more than enough to sleep for 1 million cycles. +// That is all what we would observe on a GEMM benchmark. However, in a real +// application, after having finished a GEMM, we might do unrelated work for +// a little while, then start on a new GEMM. Think of a neural network +// application performing inference, where many but not all layers are +// implemented by a GEMM. In such cases, our worker threads might be idle for +// longer periods of time before having work again. If we let them passively +// wait, on a mobile device, the CPU scheduler might aggressively clock down +// or even turn off the CPU cores that they were running on. That would result +// in a long delay the next time these need to be turned back on for the next +// GEMM. So we need to strike a balance that reflects typical time intervals +// between consecutive GEMM invokations, not just intra-GEMM considerations. +// Of course, we need to balance keeping CPUs spinning longer to resume work +// faster, versus passively waiting to conserve power. +const int kMaxBusyWaitNOPs = 4 * 1000 * 1000; + +// On X86 and ARM platforms we may use NOP instructions to know how long we +// are busy-waiting. #if defined(GEMMLOWP_ALLOW_INLINE_ASM) && !defined(GEMMLOWP_NO_BUSYWAIT) && \ (defined(GEMMLOWP_ARM) || defined(GEMMLOWP_X86)) -#define GEMMLOWP_USE_BUSYWAIT - -const int kMaxBusyWaitNOPs = 32 * 1000 * 1000; - #define GEMMLOWP_NOP "nop\n" #define GEMMLOWP_STRING_CONCAT_4(X) X X X X @@ -43,46 +63,26 @@ const int kMaxBusyWaitNOPs = 32 * 1000 * 1000; #define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4) #define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16) -inline int Do256NOPs() { +inline int DoSomeNOPs() { asm volatile(GEMMLOWP_NOP64); return 64; } #undef GEMMLOWP_STRING_CONCAT_4 -#undef GEMMLOWP_NOP256 #undef GEMMLOWP_NOP64 #undef GEMMLOWP_NOP16 #undef GEMMLOWP_NOP4 #undef GEMMLOWP_NOP -inline void WriteBarrier() { -#if defined(_MSC_VER) - MemoryBarrier(); -#elif defined(GEMMLOWP_ARM_32) - asm volatile("" ::: "memory"); -#elif defined(GEMMLOWP_ARM_64) - asm volatile("dmb ishst" ::: "memory"); -#elif defined(GEMMLOWP_X86) - asm volatile("sfence" ::: "memory"); -#else -#error "Unsupported architecture for WriteBarrier." -#endif -} +#else // May not use asm NOP. -inline void ReadBarrier() { -#if defined(_MSC_VER) - MemoryBarrier(); -#elif defined(GEMMLOWP_ARM_32) - asm volatile("" ::: "memory"); -#elif defined(GEMMLOWP_ARM_64) - asm volatile("dmb ishld" ::: "memory"); -#elif defined(GEMMLOWP_X86) - asm volatile("lfence" ::: "memory"); -#else -#error "Unsupported architecture for ReadBarrier." -#endif +// If we can't use NOPs, let's use a non-inline function call as a basic +// thing that has some vaguely known, nonzero cost. +GEMMLOWP_NOINLINE +inline int DoSomeNOPs() { + // Pretend that calling an empty function takes as long as 16 NOPs... + return 16; } - #endif // Waits until *var != initial_value. @@ -108,37 +108,29 @@ inline void ReadBarrier() { // so as to avoid permanently spinning. // template <typename T> -T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond, - pthread_mutex_t* mutex) { -#ifdef GEMMLOWP_USE_BUSYWAIT - // If we are on a platform that supports it, spin for some time. 
- { - int nops = 0; - // First, trivial case where the variable already changed value. - T new_value = *var; +T WaitForVariableChange(std::atomic<T>* var, T initial_value, + pthread_cond_t* cond, pthread_mutex_t* mutex) { + // First, trivial case where the variable already changed value. + T new_value = var->load(std::memory_order_acquire); + if (new_value != initial_value) { + return new_value; + } + // Then try busy-waiting. + int nops = 0; + while (nops < kMaxBusyWaitNOPs) { + nops += DoSomeNOPs(); + new_value = var->load(std::memory_order_acquire); if (new_value != initial_value) { - ReadBarrier(); return new_value; } - // Then try busy-waiting. - while (nops < kMaxBusyWaitNOPs) { - nops += Do256NOPs(); - new_value = *var; - if (new_value != initial_value) { - ReadBarrier(); - return new_value; - } - } } -#endif // Finally, do real passive waiting. pthread_mutex_lock(mutex); - T new_value = *var; - if (new_value == initial_value) { + new_value = var->load(std::memory_order_acquire); + while (new_value == initial_value) { pthread_cond_wait(cond, mutex); - new_value = *var; - assert(new_value != initial_value); + new_value = var->load(std::memory_order_acquire); } pthread_mutex_unlock(mutex); return new_value; @@ -147,73 +139,74 @@ T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond, // A BlockingCounter lets one thread to wait for N events to occur. // This is how the master thread waits for all the worker threads // to have finished working. +// The waiting is done using a naive spinlock waiting for the atomic +// count_ to hit the value 0. This is acceptable because in our usage +// pattern, BlockingCounter is used only to synchronize threads after +// short-lived tasks (performing parts of the same GEMM). It is not used +// for synchronizing longer waits (resuming work on the next GEMM). class BlockingCounter { public: - BlockingCounter() : count_(0), initial_count_(0) { - pthread_cond_init(&cond_, nullptr); - pthread_mutex_init(&mutex_, nullptr); - } - - ~BlockingCounter() { - pthread_cond_destroy(&cond_); - pthread_mutex_destroy(&mutex_); - } + BlockingCounter() : count_(0) {} // Sets/resets the counter; initial_count is the number of // decrementing events that the Wait() call will be waiting for. void Reset(std::size_t initial_count) { - pthread_mutex_lock(&mutex_); - assert(count_ == 0); - initial_count_ = initial_count; - count_ = initial_count_; - pthread_mutex_unlock(&mutex_); + std::size_t old_count_value = count_.load(std::memory_order_relaxed); + assert(old_count_value == 0); + (void)old_count_value; + count_.store(initial_count, std::memory_order_release); } // Decrements the counter; if the counter hits zero, signals - // the thread that was waiting for that, and returns true. + // the threads that were waiting for that, and returns true. // Otherwise (if the decremented count is still nonzero), // returns false. bool DecrementCount() { - pthread_mutex_lock(&mutex_); - assert(count_ > 0); - count_--; -#ifdef GEMMLOWP_USE_BUSYWAIT - WriteBarrier(); -#endif - if (count_ == 0) { - pthread_cond_signal(&cond_); - } - bool retval = count_ == 0; - pthread_mutex_unlock(&mutex_); - return retval; + std::size_t old_count_value = + count_.fetch_sub(1, std::memory_order_acq_rel); + assert(old_count_value > 0); + std::size_t count_value = old_count_value - 1; + return count_value == 0; } // Waits for the N other threads (N having been set by Reset()) // to hit the BlockingCounter. 
void Wait() { ScopedProfilingLabel label("BlockingCounter::Wait"); - while (count_) { -#ifdef GEMMLOWP_USE_BUSYWAIT - ReadBarrier(); -#else - // This is likely unnecessary, but is kept to ensure regressions are not - // introduced. -#ifndef _WIN32 - asm volatile("" ::: "memory"); -#endif -#endif - const std::size_t count_value = count_; - if (count_value) { - WaitForVariableChange(&count_, count_value, &cond_, &mutex_); + // Busy-wait until the count value is 0. + int nops = 0; + while (count_.load(std::memory_order_acquire)) { + nops += DoSomeNOPs(); + if (nops > kMaxBusyWaitNOPs) { + nops = 0; + // If we are unlucky, the blocking thread (that calls DecrementCount) + // and the blocked thread (here, calling Wait) may be scheduled on + // the same CPU, so the busy-waiting of the present thread may prevent + // the blocking thread from resuming and unblocking. + // If we are even unluckier, the priorities of the present thread + // might be higher than that of the blocking thread, so just yielding + // wouldn't allow the blocking thread to resume. So we sleep for + // a substantial amount of time in that case. Notice that we only + // do so after having busy-waited for kMaxBusyWaitNOPs, which is + // typically several milliseconds, so sleeping 1 more millisecond + // isn't terrible at that point. + // + // How this is mitigated in practice: + // In practice, it is well known that the application should be + // conservative in choosing how many threads to tell gemmlowp to use, + // as it's hard to know how many CPU cores it will get to run on, + // on typical mobile devices. + // It seems impossible for gemmlowp to make this choice automatically, + // which is why gemmlowp's default is to use only 1 thread, and + // applications may override that if they know that they can count on + // using more than that. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); } } } private: - pthread_cond_t cond_; - pthread_mutex_t mutex_; - std::size_t count_; - std::size_t initial_count_; + std::atomic<std::size_t> count_; }; // A workload for a worker. @@ -253,11 +246,15 @@ class Worker { // Changes State; may be called from either the worker thread // or the master thread; however, not all state transitions are legal, // which is guarded by assertions. - void ChangeState(State new_state) { + // + // The Task argument is to be used only with new_state==HasWork. + // It specifies the Task being handed to this Worker. + void ChangeState(State new_state, Task* task = nullptr) { ScopedProfilingLabel label("Worker::ChangeState"); pthread_mutex_lock(&state_mutex_); - assert(new_state != state_); - switch (state_) { + State old_state = state_.load(std::memory_order_relaxed); + assert(old_state != new_state); + switch (old_state) { case State::ThreadStartup: assert(new_state == State::Ready); break; @@ -272,18 +269,33 @@ class Worker { default: abort(); } - state_ = new_state; - pthread_cond_signal(&state_cond_); - if (state_ == State::Ready) { - counter_to_decrement_when_ready_->DecrementCount(); + switch (new_state) { + case State::Ready: + if (task_) { + // Doing work is part of reverting to 'ready' state. 
+ task_->Run(); + task_ = nullptr; + } + break; + case State::HasWork: + assert(!task_); + task->local_allocator = &local_allocator_; + task_ = task; + break; + default: + break; } + state_.store(new_state, std::memory_order_relaxed); + pthread_cond_broadcast(&state_cond_); pthread_mutex_unlock(&state_mutex_); + if (new_state == State::Ready) { + counter_to_decrement_when_ready_->DecrementCount(); + } } // Thread entry point. void ThreadFunc() { ScopedProfilingLabel label("Worker::ThreadFunc"); - RegisterCurrentThreadForProfiling(); ChangeState(State::Ready); @@ -299,9 +311,6 @@ class Worker { switch (state_to_act_upon) { case State::HasWork: // Got work to do! So do it, and then revert to 'Ready' state. - assert(task_); - task_->Run(); - task_ = nullptr; ChangeState(State::Ready); break; case State::ExitAsSoonAsPossible: @@ -318,17 +327,7 @@ class Worker { } // Called by the master thead to give this worker work to do. - // It is only legal to call this if the worker - void StartWork(Task* task) { - assert(!task_); - task->local_allocator = &local_allocator_; - task_ = task; -#ifdef GEMMLOWP_USE_BUSYWAIT - WriteBarrier(); -#endif - assert(state_ == State::Ready); - ChangeState(State::HasWork); - } + void StartWork(Task* task) { ChangeState(State::HasWork, task); } private: // The underlying thread. @@ -342,7 +341,10 @@ class Worker { pthread_mutex_t state_mutex_; // The state enum tells if we're currently working, waiting for work, etc. - State state_; + // Its concurrent accesses by the worker and main threads are guarded by + // state_mutex_, and can thus use memory_order_relaxed. This still needs + // to be a std::atomic because we use WaitForVariableChange. + std::atomic<State> state_; // Each thread had a local allocator so they can allocate temporary // buffers without blocking each other. @@ -359,9 +361,7 @@ class Worker { // waits for all of them to finish. // // See MultiThreadGemmContextBase for how other WorkersPool implementations can -// be used. Note that in those implementations, StartWorker can be free to -// ignore the <index> value; that is, the caller of WorkersPool does not rely on -// <index> to order tasks with equal <index>. +// be used. class WorkersPool { public: WorkersPool() {} @@ -372,18 +372,41 @@ class WorkersPool { } } - void Execute(const std::vector<Task*>& tasks) { - assert(tasks.size() >= 1); + // Just executes the tasks. Does not destroy them. Similar to + // ruy::ThreadPool::Execute. + template <typename TaskType> + void Execute(int tasks_count, TaskType* tasks) { + assert(tasks_count >= 1); // One of the tasks will be run on the current thread. - std::size_t workers_count = tasks.size() - 1; + std::size_t workers_count = tasks_count - 1; CreateWorkers(workers_count); assert(workers_count <= workers_.size()); counter_to_decrement_when_ready_.Reset(workers_count); - int n = 0; - std::for_each(tasks.begin(), --tasks.end(), - [this, &n](Task* task) { workers_[n++]->StartWork(task); }); + for (std::size_t i = 0; i < tasks_count - 1; i++) { + workers_[i]->StartWork(&tasks[i]); + } // Execute the remaining workload immediately on the current thread. - Task* task = tasks.back(); + Task* task = &tasks[tasks_count - 1]; + task->local_allocator = &main_thread_task_allocator_; + task->Run(); + // Wait for the workers submitted above to finish. 
+ counter_to_decrement_when_ready_.Wait(); + } + + // Legacy: executes the tasks and destroys them + void LegacyExecuteAndDestroyTasks(const std::vector<Task*>& tasks) { + std::size_t tasks_count = tasks.size(); + assert(tasks_count >= 1); + // One of the tasks will be run on the current thread. + std::size_t workers_count = tasks_count - 1; + CreateWorkers(workers_count); + assert(workers_count <= workers_.size()); + counter_to_decrement_when_ready_.Reset(workers_count); + for (int i = 0; i < tasks_count - 1; i++) { + workers_[i]->StartWork(tasks[i]); + } + // Execute the remaining workload immediately on the current thread. + Task* task = tasks[tasks_count - 1]; task->local_allocator = &main_thread_task_allocator_; task->Run(); // Wait for the workers submitted above to finish. @@ -393,6 +416,11 @@ class WorkersPool { std::for_each(tasks.begin(), tasks.end(), [](Task* task) { delete task; }); } + // Legacy old name of LegacyExecuteAndDestroyTasks + void Execute(const std::vector<Task*>& tasks) { + LegacyExecuteAndDestroyTasks(tasks); + } + private: // Ensures that the pool has at least the given count of workers. // If any new worker has to be created, this function waits for it to diff --git a/internal/output.h b/internal/output.h index dcfe2b5..92bf7b9 100644 --- a/internal/output.h +++ b/internal/output.h @@ -22,6 +22,7 @@ #include <cmath> #include <tuple> #include <type_traits> +#include <typeinfo> #include "../fixedpoint/fixedpoint.h" #include "../public/output_stages.h" @@ -179,7 +180,47 @@ struct OutputStageEvalBufferImpl<OutputStageScaleInt32ByFixedPointAndExponent, int right_shift; }; -// Implementation of OutputStageSaturatingCastToUint8 for scalar data +template <int Rows, int Cols, VectorShape Shape> +struct OutputStageEvalImpl< + OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>, + RegisterBlock<std::int32_t, Rows, Cols>> { + typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; + typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; + + typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> OutputStage; + + OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} + + OutputType Eval(InputType input, int row, int col) const { + OutputType output; + const int pos = Shape == VectorShape::Row ? col : row; + using RegisterType = typename InputType::RegisterType; + const RegisterType result_offset_after_shift = + Dup<RegisterType>(output_stage.result_offset_after_shift); + auto left_shift = + LoadForBroadcasting<InputType>(output_stage.result_exponent, pos); + auto right_shift = + LoadForBroadcasting<InputType>(output_stage.result_exponent, pos); + const auto result_fixedpoint_multiplier = LoadForBroadcasting<InputType>( + output_stage.result_fixedpoint_multiplier, pos); + for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) { + left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0); + right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0); + } + const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul( + BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier); + const auto rdpot_val = + BroadcastRoundingDivideByPOT(mulhigh_val, right_shift); + for (int i = 0; i < InputType::kRegisterCount; i++) { + output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift); + } + return output; + } + + const OutputStage& output_stage; +}; + +// Implementation of OutputStageSaturatingCastToUint8 for scalar data. 
template <int Size> struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, RegisterBuffer<std::int32_t, Size>> { @@ -202,7 +243,30 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, } }; -// Implementation of OutputStageSaturatingCastToInt16 for scalar data +// Implementation of OutputStageSaturatingCastToInt8 for scalar data. +template <int Size> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8, + RegisterBuffer<std::int32_t, Size>> { + typedef RegisterBuffer<std::int32_t, Size> InputType; + typedef RegisterBuffer<std::int8_t, Size> OutputType; + static_assert(InputType::kRegisterLanes == 1, + "This path is only for scalar values"); + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + for (int i = 0; i < InputType::kRegisterCount; i++) { + std::int32_t data = input.reg[i]; + output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data; + } + return output; + } +}; + +// Implementation of OutputStageSaturatingCastToInt16 for scalar data. template <int Size> struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, RegisterBuffer<std::int32_t, Size>> { @@ -225,6 +289,28 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, } }; +// Implementation of OutputStageTruncatingCastToUint8 for scalar data +template <int Size> +struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8, + RegisterBuffer<std::int32_t, Size>> { + typedef RegisterBuffer<std::int32_t, Size> InputType; + typedef RegisterBuffer<std::uint8_t, Size> OutputType; + static_assert(InputType::kRegisterLanes == 1, + "This path is only for scalar values"); + + typedef OutputStageTruncatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + for (int i = 0; i < InputType::kRegisterCount; i++) { + output.reg[i] = input.reg[i]; + } + return output; + } +}; + template <int Rows, int Cols, typename VectorType> struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>, RegisterBlock<std::int32_t, Rows, Cols>> { @@ -452,7 +538,7 @@ struct OutputPipelineExecutor { OutputPipelineExecutor(const OutputPipelineType& output_pipeline) : output_pipeline_eval_impl_(output_pipeline) {} - // RunOutputPipeline is the entry point into the output pipeline evaluation + // Execute is the entry point into the output pipeline evaluation // code. It should be the only thing that unpack code calls. It takes the // result // of the unpack stage and stores it into the destination matrix. diff --git a/internal/output_avx.h b/internal/output_avx.h new file mode 100644 index 0000000..b8f94fb --- /dev/null +++ b/internal/output_avx.h @@ -0,0 +1,19 @@ +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// output_avx.h: optimized AVX 2 specializations of the templates in output.h. 
+ +#ifndef GEMMLOWP_INTERNAL_OUTPUT_AVX_H_ +#define GEMMLOWP_INTERNAL_OUTPUT_AVX_H_ + +#endif // GEMMLOWP_INTERNAL_OUTPUT_AVX_H_ diff --git a/internal/output_msa.h b/internal/output_msa.h index 4c8eb5d..0540bb3 100644 --- a/internal/output_msa.h +++ b/internal/output_msa.h @@ -38,18 +38,14 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, // Signed saturate each 32-bit element to 9 bits // (this takes full care of non-negative elements). v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8); + // Zero out negative elements. + tmp = __builtin_msa_maxi_s_w(tmp, 0); // Pack every 32-bit element into 16 bits. tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp))); - // Detect negative elements with arithmetic shift right (we - // get a 16-bit mask of all zeroes or all ones for every element). - v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15); - // Zero out negative elements. - signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( - reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0)); // Pack every element into 8 bits. tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b( - reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs))); + reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp))); // Return 4 uint8_t elements as uint32_t. output.reg[0] = __builtin_msa_copy_s_w(tmp, 0); return output; @@ -76,15 +72,12 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, // combining all 8 elements into one vector. tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo))); - // Detect negative elements with arithmetic shift right (we - // get a 16-bit mask of all zeroes or all ones for every element). - v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15); // Zero out negative elements. - signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( - reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0)); + tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h( + reinterpret_cast<v8i16>(tmp_lo), 0)); // Pack every element into 8 bits. tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b( - reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs))); + reinterpret_cast<v16i8>(tmp_lo), reinterpret_cast<v16i8>(tmp_lo))); // Return 8 uint8_t elements as 2 uint32_t's. 
output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0); output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1); @@ -102,15 +95,13 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0))); \ tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( \ reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2))); \ - v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15); \ - v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15); \ - signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \ - reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \ - signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \ - reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \ - signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b( \ - reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0))); \ - out = reinterpret_cast<v16i8>(signs0); \ + tmp0 = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h( \ + reinterpret_cast<v8i16>(tmp0), 0)); \ + tmp2 = reinterpret_cast<v4i32>(__builtin_msa_maxi_s_h( \ + reinterpret_cast<v8i16>(tmp2), 0)); \ + tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_b( \ + reinterpret_cast<v16i8>(tmp2), reinterpret_cast<v16i8>(tmp0))); \ + out = reinterpret_cast<v16i8>(tmp0); \ } template <> @@ -166,8 +157,8 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, OutputType Eval(InputType input) const { OutputType output; // Signed saturate each 32-bit element to 16 bits. - v8i16 tmp = reinterpret_cast<v8i16>(__builtin_msa_sat_s_w( - input.reg[0], 15)); + v8i16 tmp = + reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(input.reg[0], 15)); output.reg[0] = __builtin_msa_copy_s_h(tmp, 0); output.reg[1] = __builtin_msa_copy_s_h(tmp, 2); output.reg[2] = __builtin_msa_copy_s_h(tmp, 4); @@ -176,12 +167,12 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, } }; -#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1) \ - { \ - v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15); \ - v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15); \ - out = __builtin_msa_pckev_h( \ - reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)); \ +#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1) \ + { \ + v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15); \ + v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15); \ + out = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(tmp1), \ + reinterpret_cast<v8i16>(tmp0)); \ } template <> @@ -241,6 +232,117 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, #undef GEMMLOWP_MIPS_SAT_I16_8 +template <> +struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8, + RegBufferInt32<4>> { + typedef RegBufferInt32<4> InputType; + typedef RegBufferUint8<4> OutputType; + + typedef OutputStageTruncatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + // Pack every 32-bit element into 16 bits. + v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( + reinterpret_cast<v8i16>(input.reg[0]), + reinterpret_cast<v8i16>(input.reg[0]))); + // Pack every element into 8 bits. + tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b( + reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp))); + // Return 4 uint8_t elements as uint32_t. 
+ output.reg[0] = __builtin_msa_copy_s_w(tmp, 0); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8, + RegBufferInt32<8>> { + typedef RegBufferInt32<8> InputType; + typedef RegBufferUint8<8> OutputType; + + typedef OutputStageTruncatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + // Pack every 32-bit element into 16 bits. + v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( + reinterpret_cast<v8i16>(input.reg[1]), + reinterpret_cast<v8i16>(input.reg[0]))); + // Pack every element into 8 bits. + tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b( + reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp))); + // Return 8 uint8_t elements as 2 uint32_t's. + output.reg[0] = __builtin_msa_copy_s_w(tmp, 0); + output.reg[1] = __builtin_msa_copy_s_w(tmp, 1); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8, + RegBufferInt32<16>> { + typedef RegBufferInt32<16> InputType; + typedef RegBufferUint8<16> OutputType; + + typedef OutputStageTruncatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + // Pack every 32-bit element into 16 bits. + v8i16 tmp0 = __builtin_msa_pckev_h( + reinterpret_cast<v8i16>(input.reg[1]), + reinterpret_cast<v8i16>(input.reg[0])); + v8i16 tmp1 = __builtin_msa_pckev_h( + reinterpret_cast<v8i16>(input.reg[3]), + reinterpret_cast<v8i16>(input.reg[2])); + // Pack every element into 8 bits. + output.reg[0] = __builtin_msa_pckev_b( + reinterpret_cast<v16i8>(tmp1), reinterpret_cast<v16i8>(tmp0)); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8, + RegBufferInt32<32>> { + typedef RegBufferInt32<32> InputType; + typedef RegBufferUint8<32> OutputType; + + typedef OutputStageTruncatingCastToUint8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + // Pack every 32-bit element into 16 bits. + v8i16 tmp0 = __builtin_msa_pckev_h( + reinterpret_cast<v8i16>(input.reg[1]), + reinterpret_cast<v8i16>(input.reg[0])); + v8i16 tmp1 = __builtin_msa_pckev_h( + reinterpret_cast<v8i16>(input.reg[3]), + reinterpret_cast<v8i16>(input.reg[2])); + v8i16 tmp2 = __builtin_msa_pckev_h( + reinterpret_cast<v8i16>(input.reg[5]), + reinterpret_cast<v8i16>(input.reg[4])); + v8i16 tmp3 = __builtin_msa_pckev_h( + reinterpret_cast<v8i16>(input.reg[7]), + reinterpret_cast<v8i16>(input.reg[6])); + // Pack every element into 8 bits. 
+ output.reg[0] = __builtin_msa_pckev_b( + reinterpret_cast<v16i8>(tmp1), reinterpret_cast<v16i8>(tmp0)); + output.reg[1] = __builtin_msa_pckev_b( + reinterpret_cast<v16i8>(tmp3), reinterpret_cast<v16i8>(tmp2)); + return output; + } +}; + template <typename DstType> struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> { static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row, @@ -474,50 +576,50 @@ struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> { } } else { // top-left 4x4 - v4i32 t0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[1], - src.buf.reg[0])); - v4i32 t1 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[3], - src.buf.reg[2])); + v4i32 t0 = reinterpret_cast<v4i32>( + __builtin_msa_ilvr_h(src.buf.reg[1], src.buf.reg[0])); + v4i32 t1 = reinterpret_cast<v4i32>( + __builtin_msa_ilvr_h(src.buf.reg[3], src.buf.reg[2])); v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0)); v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0)); // top-right 4x4 - v4i32 t2 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[5], - src.buf.reg[4])); - v4i32 t3 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[7], - src.buf.reg[6])); + v4i32 t2 = reinterpret_cast<v4i32>( + __builtin_msa_ilvr_h(src.buf.reg[5], src.buf.reg[4])); + v4i32 t3 = reinterpret_cast<v4i32>( + __builtin_msa_ilvr_h(src.buf.reg[7], src.buf.reg[6])); v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2)); v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2)); // bottom-left 4x4 - v4i32 t4 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[1], - src.buf.reg[0])); - v4i32 t5 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[3], - src.buf.reg[2])); + v4i32 t4 = reinterpret_cast<v4i32>( + __builtin_msa_ilvl_h(src.buf.reg[1], src.buf.reg[0])); + v4i32 t5 = reinterpret_cast<v4i32>( + __builtin_msa_ilvl_h(src.buf.reg[3], src.buf.reg[2])); v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4)); v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4)); // bottom-right 4x4 - v4i32 t6 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[5], - src.buf.reg[4])); - v4i32 t7 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[7], - src.buf.reg[6])); + v4i32 t6 = reinterpret_cast<v4i32>( + __builtin_msa_ilvl_h(src.buf.reg[5], src.buf.reg[4])); + v4i32 t7 = reinterpret_cast<v4i32>( + __builtin_msa_ilvl_h(src.buf.reg[7], src.buf.reg[6])); v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6)); v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6)); - StoreInt16x8(dst->data(row + 0, col), reinterpret_cast<v8i16>( - __builtin_msa_ilvr_d(u2, u0))); - StoreInt16x8(dst->data(row + 1, col), reinterpret_cast<v8i16>( - __builtin_msa_ilvl_d(u2, u0))); - StoreInt16x8(dst->data(row + 2, col), reinterpret_cast<v8i16>( - __builtin_msa_ilvr_d(u3, u1))); - StoreInt16x8(dst->data(row + 3, col), reinterpret_cast<v8i16>( - __builtin_msa_ilvl_d(u3, u1))); - StoreInt16x8(dst->data(row + 4, col), reinterpret_cast<v8i16>( - __builtin_msa_ilvr_d(u6, u4))); - StoreInt16x8(dst->data(row + 5, col), reinterpret_cast<v8i16>( - __builtin_msa_ilvl_d(u6, u4))); - StoreInt16x8(dst->data(row + 6, col), reinterpret_cast<v8i16>( - __builtin_msa_ilvr_d(u7, u5))); - StoreInt16x8(dst->data(row + 7, col), reinterpret_cast<v8i16>( - __builtin_msa_ilvl_d(u7, u5))); + StoreInt16x8(dst->data(row + 0, col), + reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u2, u0))); + StoreInt16x8(dst->data(row + 1, col), + 
reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u2, u0))); + StoreInt16x8(dst->data(row + 2, col), + reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u3, u1))); + StoreInt16x8(dst->data(row + 3, col), + reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u3, u1))); + StoreInt16x8(dst->data(row + 4, col), + reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u6, u4))); + StoreInt16x8(dst->data(row + 5, col), + reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u6, u4))); + StoreInt16x8(dst->data(row + 6, col), + reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u7, u5))); + StoreInt16x8(dst->data(row + 7, col), + reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u7, u5))); } } }; @@ -585,6 +687,391 @@ struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> { } }; +// There's no way to express in C++ the desired machine code for +// StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> and +// StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType>. +// Hence, if we can, we use inline assembly, which takes advantage +// of little-endian byte order and specifics of different CPU revisions. +// Note, clang currently can't derive MSA register names from floating- +// point register names and vice versa in inline assembly. +#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) && \ + !defined(__clang__) + +// Instructions for pointer-sized operands. +#ifdef GEMMLOWP_MIPS_64 +#define GEMMLOWP_MIPS_XADDU "daddu" +#define GEMMLOWP_MIPS_XLSA "dlsa" +#else +#define GEMMLOWP_MIPS_XADDU "addu" +#define GEMMLOWP_MIPS_XLSA "lsa" +#endif + +// Stores 4 8-byte half-vectors with a stride. +inline void MipsMsaStore4x8(const RegBlockUint8<8, 4>& src, + std::uint8_t* dst_ptr, int stride) { +#if (__mips_isa_rev >= 6) + // Assembly temporaries that will be handily referred to by their names. + std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3; + v16i8 vtmp0, vtmp1; + asm volatile( + GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n" + "ilvl.d %w[vtmp0], %w[src0], %w[src0]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n" + "ilvl.d %w[vtmp1], %w[src1], %w[src1]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n" + "sdc1 %[src0], 0(%[dst_ptr0])\n" + "sdc1 %[vtmp0], 0(%[dst_ptr1])\n" + "sdc1 %[src1], 0(%[dst_ptr2])\n" + "sdc1 %[vtmp1], 0(%[dst_ptr3])\n" + : + // Outputs. + [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1), + [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3), + [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1) + : + // Inputs. + [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]), + [stride] "r"(stride) + : + // Clobbers. + "memory"); +#else + // Assembly temporaries that will be handily referred to by their names. 
+ std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3; + int tmp0, tmp1, tmp2, tmp3; + asm volatile( + GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n" + "copy_s.w %[tmp0], %w[src0][0]\n" + "copy_s.w %[tmp1], %w[src0][1]\n" + "copy_s.w %[tmp2], %w[src0][2]\n" + "copy_s.w %[tmp3], %w[src0][3]\n" + "swr %[tmp0], 0(%[dst_ptr0])\n" + "swl %[tmp0], 3(%[dst_ptr0])\n" + "swr %[tmp1], 4(%[dst_ptr0])\n" + "swl %[tmp1], 7(%[dst_ptr0])\n" + "swr %[tmp2], 0(%[dst_ptr1])\n" + "swl %[tmp2], 3(%[dst_ptr1])\n" + "swr %[tmp3], 4(%[dst_ptr1])\n" + "swl %[tmp3], 7(%[dst_ptr1])\n" + "copy_s.w %[tmp0], %w[src1][0]\n" + "copy_s.w %[tmp1], %w[src1][1]\n" + "copy_s.w %[tmp2], %w[src1][2]\n" + "copy_s.w %[tmp3], %w[src1][3]\n" + "swr %[tmp0], 0(%[dst_ptr2])\n" + "swl %[tmp0], 3(%[dst_ptr2])\n" + "swr %[tmp1], 4(%[dst_ptr2])\n" + "swl %[tmp1], 7(%[dst_ptr2])\n" + "swr %[tmp2], 0(%[dst_ptr3])\n" + "swl %[tmp2], 3(%[dst_ptr3])\n" + "swr %[tmp3], 4(%[dst_ptr3])\n" + "swl %[tmp3], 7(%[dst_ptr3])\n" + : + // Outputs. + [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1), + [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3), [tmp0] "=&r"(tmp0), + [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), [tmp3] "=&r"(tmp3) + : + // Inputs. + [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]), + [stride] "r"(stride) + : + // Clobbers. + "memory"); +#endif +} + +// Stores 8 4-byte quarter-vectors with a stride. +inline void MipsMsaStore8x4(const RegBlockUint8<4, 8>& src, + std::uint8_t* dst_ptr, int stride) { +#if (__mips_isa_rev >= 6) + // Assembly temporaries that will be handily referred to by their names. + std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5, + *dst_ptr6, *dst_ptr7; + int tmp1, tmp2, tmp3; + asm volatile( + GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n" + "copy_s.w %[tmp1], %w[src0][1]\n" + "copy_s.w %[tmp2], %w[src0][2]\n" + "copy_s.w %[tmp3], %w[src0][3]\n" + "swc1 %[src0], 0(%[dst_ptr0])\n" + "sw %[tmp1], 0(%[dst_ptr1])\n" + "sw %[tmp2], 0(%[dst_ptr2])\n" + "sw %[tmp3], 0(%[dst_ptr3])\n" + "copy_s.w %[tmp1], %w[src1][1]\n" + "copy_s.w %[tmp2], %w[src1][2]\n" + "copy_s.w %[tmp3], %w[src1][3]\n" + "swc1 %[src1], 0(%[dst_ptr4])\n" + "sw %[tmp1], 0(%[dst_ptr5])\n" + "sw %[tmp2], 0(%[dst_ptr6])\n" + "sw %[tmp3], 0(%[dst_ptr7])\n" + : + // Outputs. + [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1), + [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3), + [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5), + [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7), + [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), [tmp3] "=&r"(tmp3) + : + // Inputs. + [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]), + [stride] "r"(stride) + : + // Clobbers. + "memory"); +#else + // Assembly temporaries that will be handily referred to by their names. 
+ std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5, + *dst_ptr6, *dst_ptr7; + int tmp0, tmp1, tmp2, tmp3; + asm volatile( + GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n" + "copy_s.w %[tmp0], %w[src0][0]\n" + "copy_s.w %[tmp1], %w[src0][1]\n" + "copy_s.w %[tmp2], %w[src0][2]\n" + "copy_s.w %[tmp3], %w[src0][3]\n" + "swr %[tmp0], 0(%[dst_ptr0])\n" + "swl %[tmp0], 3(%[dst_ptr0])\n" + "swr %[tmp1], 0(%[dst_ptr1])\n" + "swl %[tmp1], 3(%[dst_ptr1])\n" + "swr %[tmp2], 0(%[dst_ptr2])\n" + "swl %[tmp2], 3(%[dst_ptr2])\n" + "swr %[tmp3], 0(%[dst_ptr3])\n" + "swl %[tmp3], 3(%[dst_ptr3])\n" + "copy_s.w %[tmp0], %w[src1][0]\n" + "copy_s.w %[tmp1], %w[src1][1]\n" + "copy_s.w %[tmp2], %w[src1][2]\n" + "copy_s.w %[tmp3], %w[src1][3]\n" + "swr %[tmp0], 0(%[dst_ptr4])\n" + "swl %[tmp0], 3(%[dst_ptr4])\n" + "swr %[tmp1], 0(%[dst_ptr5])\n" + "swl %[tmp1], 3(%[dst_ptr5])\n" + "swr %[tmp2], 0(%[dst_ptr6])\n" + "swl %[tmp2], 3(%[dst_ptr6])\n" + "swr %[tmp3], 0(%[dst_ptr7])\n" + "swl %[tmp3], 3(%[dst_ptr7])\n" + : + // Outputs. + [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1), + [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3), + [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5), + [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7), + [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), + [tmp3] "=&r"(tmp3) + : + // Inputs. + [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]), + [stride] "r"(stride) + : + // Clobbers. + "memory"); +#endif +} + +// Stores 8 8-byte half-vectors with a stride. +inline void MipsMsaStore8x8(const RegBlockUint8<8, 8>& src, + std::uint8_t* dst_ptr, int stride) { +#if (__mips_isa_rev >= 6) + // Assembly temporaries that will be handily referred to by their names. + std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5, + *dst_ptr6, *dst_ptr7; + v16i8 vtmp0, vtmp1, vtmp2, vtmp3; + asm volatile( + "ilvl.d %w[vtmp0], %w[src0], %w[src0]\n" + GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n" + "ilvl.d %w[vtmp1], %w[src1], %w[src1]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n" + "ilvl.d %w[vtmp2], %w[src2], %w[src2]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n" + "ilvl.d %w[vtmp3], %w[src3], %w[src3]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n" + "sdc1 %[src0], 0(%[dst_ptr0])\n" + "sdc1 %[vtmp0], 0(%[dst_ptr1])\n" + "sdc1 %[src1], 0(%[dst_ptr2])\n" + "sdc1 %[vtmp1], 0(%[dst_ptr3])\n" + "sdc1 %[src2], 0(%[dst_ptr4])\n" + "sdc1 %[vtmp2], 0(%[dst_ptr5])\n" + "sdc1 %[src3], 0(%[dst_ptr6])\n" + "sdc1 %[vtmp3], 0(%[dst_ptr7])\n" + : + // Outputs. 
+ [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1), + [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3), + [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5), + [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7), + [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1), [vtmp2] "=&f"(vtmp2), + [vtmp3] "=&f"(vtmp3) + : + // Inputs. + [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]), + [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]), + [stride] "r"(stride) + : + // Clobbers. + "memory"); +#else + // Assembly temporaries that will be handily referred to by their names. + std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5, + *dst_ptr6, *dst_ptr7; + int tmp0, tmp1, tmp2, tmp3; + asm volatile( + GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n" + GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n" + "copy_s.w %[tmp0], %w[src0][0]\n" + "copy_s.w %[tmp1], %w[src0][1]\n" + "copy_s.w %[tmp2], %w[src0][2]\n" + "copy_s.w %[tmp3], %w[src0][3]\n" + "swr %[tmp0], 0(%[dst_ptr0])\n" + "swl %[tmp0], 3(%[dst_ptr0])\n" + "swr %[tmp1], 4(%[dst_ptr0])\n" + "swl %[tmp1], 7(%[dst_ptr0])\n" + "swr %[tmp2], 0(%[dst_ptr1])\n" + "swl %[tmp2], 3(%[dst_ptr1])\n" + "swr %[tmp3], 4(%[dst_ptr1])\n" + "swl %[tmp3], 7(%[dst_ptr1])\n" + "copy_s.w %[tmp0], %w[src1][0]\n" + "copy_s.w %[tmp1], %w[src1][1]\n" + "copy_s.w %[tmp2], %w[src1][2]\n" + "copy_s.w %[tmp3], %w[src1][3]\n" + "swr %[tmp0], 0(%[dst_ptr2])\n" + "swl %[tmp0], 3(%[dst_ptr2])\n" + "swr %[tmp1], 4(%[dst_ptr2])\n" + "swl %[tmp1], 7(%[dst_ptr2])\n" + "swr %[tmp2], 0(%[dst_ptr3])\n" + "swl %[tmp2], 3(%[dst_ptr3])\n" + "swr %[tmp3], 4(%[dst_ptr3])\n" + "swl %[tmp3], 7(%[dst_ptr3])\n" + "copy_s.w %[tmp0], %w[src2][0]\n" + "copy_s.w %[tmp1], %w[src2][1]\n" + "copy_s.w %[tmp2], %w[src2][2]\n" + "copy_s.w %[tmp3], %w[src2][3]\n" + "swr %[tmp0], 0(%[dst_ptr4])\n" + "swl %[tmp0], 3(%[dst_ptr4])\n" + "swr %[tmp1], 4(%[dst_ptr4])\n" + "swl %[tmp1], 7(%[dst_ptr4])\n" + "swr %[tmp2], 0(%[dst_ptr5])\n" + "swl %[tmp2], 3(%[dst_ptr5])\n" + "swr %[tmp3], 4(%[dst_ptr5])\n" + "swl %[tmp3], 7(%[dst_ptr5])\n" + "copy_s.w %[tmp0], %w[src3][0]\n" + "copy_s.w %[tmp1], %w[src3][1]\n" + "copy_s.w %[tmp2], %w[src3][2]\n" + "copy_s.w %[tmp3], %w[src3][3]\n" + "swr %[tmp0], 0(%[dst_ptr6])\n" + "swl %[tmp0], 3(%[dst_ptr6])\n" + "swr %[tmp1], 4(%[dst_ptr6])\n" + "swl %[tmp1], 7(%[dst_ptr6])\n" + "swr %[tmp2], 0(%[dst_ptr7])\n" + "swl %[tmp2], 3(%[dst_ptr7])\n" + "swr %[tmp3], 4(%[dst_ptr7])\n" + "swl %[tmp3], 7(%[dst_ptr7])\n" + : + // Outputs. + [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1), + [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3), + [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5), + [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7), + [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), + [tmp3] "=&r"(tmp3) + : + // Inputs. + [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]), + [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]), + [stride] "r"(stride) + : + // Clobbers. + "memory"); +#endif +} + +#undef GEMMLOWP_MIPS_XADDU +#undef GEMMLOWP_MIPS_XLSA + +// Transposes a column-major 8x4 block for storage into a row-major matrix. 
+inline RegBlockUint8<4, 8> Transpose(const RegBlockUint8<8, 4>& src) { + v16i8 tmp0 = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]); + v16i8 tmp1 = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]); + RegBlockUint8<4, 8> result; + result.buf.reg[0] = __builtin_msa_ilvr_b(tmp1, tmp0); + result.buf.reg[1] = __builtin_msa_ilvl_b(tmp1, tmp0); + return result; +} + +inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) { + v16i8 tmp0[4]; + tmp0[0] = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]); + tmp0[1] = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]); + tmp0[2] = __builtin_msa_ilvr_b(src.buf.reg[3], src.buf.reg[2]); + tmp0[3] = __builtin_msa_ilvl_b(src.buf.reg[3], src.buf.reg[2]); + v16i8 tmp1[4]; + tmp1[0] = __builtin_msa_ilvr_b(tmp0[1], tmp0[0]); + tmp1[1] = __builtin_msa_ilvl_b(tmp0[1], tmp0[0]); + tmp1[2] = __builtin_msa_ilvr_b(tmp0[3], tmp0[2]); + tmp1[3] = __builtin_msa_ilvl_b(tmp0[3], tmp0[2]); + RegBlockUint8<8, 8> result; + result.buf.reg[0] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w( + reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0]))); + result.buf.reg[1] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w( + reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0]))); + result.buf.reg[2] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w( + reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1]))); + result.buf.reg[3] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w( + reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1]))); + return result; +} + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> { + static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row, + int col) { + if (DstType::kOrder == MapOrder::ColMajor) { + std::uint8_t* dst_ptr = dst->data(row, col); + int col_stride = dst->cols_stride(); + MipsMsaStore4x8(src, dst_ptr, col_stride); + } else { + const auto& block = Transpose(src); + std::uint8_t* dst_ptr = dst->data(row, col); + int row_stride = dst->rows_stride(); + MipsMsaStore8x4(block, dst_ptr, row_stride); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> { + static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row, + int col) { + const auto& block = + (DstType::kOrder == MapOrder::ColMajor) ? src : Transpose(src); + std::uint8_t* dst_ptr = dst->data(row, col); + int stride = dst->stride(); + MipsMsaStore8x8(block, dst_ptr, stride); + } +}; + +#else + template <typename DstType> struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> { static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row, @@ -617,6 +1104,8 @@ struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> { } }; +#endif // Endianness, compiler. 
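The inline-assembly helpers above store an 8x4 (or 8x8) register block into the destination with a stride, transposing first when the destination is row-major. A portable sketch of the column-major 8x4 case, assuming the block's four columns are laid out as contiguous 8-byte runs; the layout and names are illustrative, not the patch's actual register handling:

    #include <cstdint>
    #include <cstring>

    // Store four 8-byte columns of a block at dst, dst + col_stride, ...
    inline void StoreBlock8x4(const std::uint8_t block[32], std::uint8_t* dst,
                              int col_stride) {
      for (int c = 0; c < 4; ++c) {
        std::memcpy(dst + c * col_stride, block + 8 * c, 8);
      }
    }
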
+ } // namespace gemmlowp #endif // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_ diff --git a/internal/output_neon.h b/internal/output_neon.h index 911fed0..52ea1bc 100644 --- a/internal/output_neon.h +++ b/internal/output_neon.h @@ -108,6 +108,90 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, }; template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8, + RegBufferInt32<4>> { + typedef RegBufferInt32<4> InputType; + typedef RegBufferInt8<4> OutputType; + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x4_t res_16 = vqmovn_s32(input.reg[0]); + int8x8_t res_8 = vqmovn_s16(vcombine_s16(res_16, res_16)); + output.reg[0] = vget_lane_s32(vreinterpret_s32_s8(res_8), 0); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8, + RegBufferInt32<8>> { + typedef RegBufferInt32<8> InputType; + typedef RegBufferInt8<8> OutputType; + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x8_t res_16 = + vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1])); + output.reg[0] = vqmovn_s16(res_16); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8, + RegBufferInt32<16>> { + typedef RegBufferInt32<16> InputType; + typedef RegBufferInt8<16> OutputType; + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x8_t res_16_0 = + vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1])); + int16x8_t res_16_1 = + vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3])); + output.reg[0] = vqmovn_s16(res_16_0); + output.reg[1] = vqmovn_s16(res_16_1); + return output; + } +}; + +template <> +struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8, + RegBufferInt32<32>> { + typedef RegBufferInt32<32> InputType; + typedef RegBufferInt8<32> OutputType; + + typedef OutputStageSaturatingCastToInt8 OutputStage; + + OutputStageEvalBufferImpl(const OutputStage&) {} + + OutputType Eval(InputType input) const { + OutputType output; + int16x8_t res_16[4]; + for (int i = 0; i < 4; i++) { + res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]), + vqmovn_s32(input.reg[2 * i + 1])); + } + for (int i = 0; i < 4; i++) { + output.reg[i] = vqmovn_s16(res_16[i]); + } + return output; + } +}; + +template <> struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, RegBufferInt32<4>> { typedef RegBufferInt32<4> InputType; @@ -556,8 +640,8 @@ struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> { vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]); } } else { + int row_stride = dst->rows_stride(); for (int i = 0; i < 4; i++) { - int row_stride = dst->rows_stride(); std::uint8_t* col_ptr = dst_ptr + i; vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0); vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1); @@ -623,6 +707,153 @@ struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> { }; template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt8<4, 1>, DstType> { + static void Run(const RegBlockInt8<4, 1>& src, DstType* dst, int row, + int col) { + const std::int32_t src_reg = src.buf.reg[0]; + for (int i = 0; i < 4; i++) { + *dst->data(row + i, col) = 
(src_reg >> (8 * i)); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt8<1, 4>, DstType> { + static void Run(const RegBlockInt8<1, 4>& src, DstType* dst, int row, + int col) { + for (int i = 0; i < 4; i++) { + *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i)); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt8<8, 1>, DstType> { + static void Run(const RegBlockInt8<8, 1>& src, DstType* dst, int row, + int col) { + std::int8_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::ColMajor) { + vst1_s8(dst_ptr, src.buf.reg[0]); + } else { + const int row_stride = dst->rows_stride(); + vst1_lane_s8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0); + vst1_lane_s8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1); + vst1_lane_s8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2); + vst1_lane_s8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3); + vst1_lane_s8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4); + vst1_lane_s8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5); + vst1_lane_s8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6); + vst1_lane_s8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt8<4, 4>, DstType> { + static void Run(const RegBlockInt8<4, 4>& src, DstType* dst, int row, + int col) { + std::int8_t* dst_ptr = dst->data(row, col); + const int row_stride = dst->rows_stride(); + const int col_stride = dst->cols_stride(); + for (int i = 0; i < 2; i++) { + vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 0); + vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 1); + vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 2); + vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride, + src.buf.reg[i], 3); + vst1_lane_s8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 4); + vst1_lane_s8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 5); + vst1_lane_s8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 6); + vst1_lane_s8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride, + src.buf.reg[i], 7); + } + } +}; + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt8<8, 4>, DstType> { + static void Run(const RegBlockInt8<8, 4>& src, DstType* dst, int row, + int col) { + std::int8_t* dst_ptr = dst->data(row, col); + if (DstType::kOrder == MapOrder::ColMajor) { + int col_stride = dst->cols_stride(); + for (int i = 0; i < 4; i++) { + vst1_s8(dst_ptr + i * col_stride, src.buf.reg[i]); + } + } else { + int row_stride = dst->rows_stride(); + for (int i = 0; i < 4; i++) { + std::int8_t* col_ptr = dst_ptr + i; + vst1_lane_s8(col_ptr + 0 * row_stride, src.buf.reg[i], 0); + vst1_lane_s8(col_ptr + 1 * row_stride, src.buf.reg[i], 1); + vst1_lane_s8(col_ptr + 2 * row_stride, src.buf.reg[i], 2); + vst1_lane_s8(col_ptr + 3 * row_stride, src.buf.reg[i], 3); + vst1_lane_s8(col_ptr + 4 * row_stride, src.buf.reg[i], 4); + vst1_lane_s8(col_ptr + 5 * row_stride, src.buf.reg[i], 5); + vst1_lane_s8(col_ptr + 6 * row_stride, src.buf.reg[i], 6); + vst1_lane_s8(col_ptr + 7 * row_stride, src.buf.reg[i], 7); + } + } + } +}; + +inline RegBlockInt8<8, 8> Transpose(const RegBlockInt8<8, 8>& src) { + int8x8x2_t a[4]; + a[0] = vtrn_s8(src.buf.reg[0], src.buf.reg[1]); + a[1] = vtrn_s8(src.buf.reg[2], src.buf.reg[3]); + a[2] = vtrn_s8(src.buf.reg[4], src.buf.reg[5]); + a[3] = vtrn_s8(src.buf.reg[6], 
src.buf.reg[7]); + int16x4x2_t b[4]; + b[0] = vtrn_s16(vreinterpret_s16_s8(a[0].val[0]), + vreinterpret_s16_s8(a[1].val[0])); + b[1] = vtrn_s16(vreinterpret_s16_s8(a[0].val[1]), + vreinterpret_s16_s8(a[1].val[1])); + b[2] = vtrn_s16(vreinterpret_s16_s8(a[2].val[0]), + vreinterpret_s16_s8(a[3].val[0])); + b[3] = vtrn_s16(vreinterpret_s16_s8(a[2].val[1]), + vreinterpret_s16_s8(a[3].val[1])); + int32x2x2_t c[4]; + c[0] = vtrn_s32(vreinterpret_s32_s16(b[0].val[0]), + vreinterpret_s32_s16(b[2].val[0])); + c[1] = vtrn_s32(vreinterpret_s32_s16(b[1].val[0]), + vreinterpret_s32_s16(b[3].val[0])); + c[2] = vtrn_s32(vreinterpret_s32_s16(b[0].val[1]), + vreinterpret_s32_s16(b[2].val[1])); + c[3] = vtrn_s32(vreinterpret_s32_s16(b[1].val[1]), + vreinterpret_s32_s16(b[3].val[1])); + RegBlockInt8<8, 8> result; + result.buf.reg[0] = vreinterpret_s8_s32(c[0].val[0]); + result.buf.reg[1] = vreinterpret_s8_s32(c[1].val[0]); + result.buf.reg[2] = vreinterpret_s8_s32(c[2].val[0]); + result.buf.reg[3] = vreinterpret_s8_s32(c[3].val[0]); + result.buf.reg[4] = vreinterpret_s8_s32(c[0].val[1]); + result.buf.reg[5] = vreinterpret_s8_s32(c[1].val[1]); + result.buf.reg[6] = vreinterpret_s8_s32(c[2].val[1]); + result.buf.reg[7] = vreinterpret_s8_s32(c[3].val[1]); + return result; +} + +template <typename DstType> +struct StoreFinalOutputImpl<RegBlockInt8<8, 8>, DstType> { + static void Run(const RegBlockInt8<8, 8>& src, DstType* dst, int row, + int col) { + const auto& block = + DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src); + std::int8_t* dst_ptr = dst->data(row, col); + int stride = dst->stride(); + for (int i = 0; i < 8; i++) { + vst1_s8(dst_ptr + i * stride, block.buf.reg[i]); + } + } +}; + +template <typename DstType> struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> { static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row, int col) { diff --git a/internal/pack.h b/internal/pack.h index cb4b93a..7c43d6e 100644 --- a/internal/pack.h +++ b/internal/pack.h @@ -72,6 +72,10 @@ class PackedSideBlock { pos_ += n * KernelSideFormat::Cell::kSize; } + // TODO(suharshs): The datatype can now be int8 as well. We could introduce a + // new int8 current_data impl as well. This change would propagate to all pack + // impls and the Kernel::Run API, which all assume uint8. For now we leave + // this as-is pending future refactor. const std::uint8_t* current_data() const { return allocator_->GetPointer<std::uint8_t>(data_handle_) + pos_; } @@ -208,6 +212,7 @@ class PackingRegisterBlockBase { public: typedef typename PackedSideBlock::KernelSideFormat KernelSideFormat; typedef typename KernelSideFormat::Cell CellFormat; + typedef typename KernelSideFormat::InputScalar KernelInputScalar; typedef typename KernelSideFormat::Scalar KernelScalar; static const int kCells = KernelSideFormat::kCells; static const int kCellWidth = CellFormat::kWidth; @@ -216,7 +221,7 @@ class PackingRegisterBlockBase { static const int kCellSize = CellFormat::kSize; static const SideMapOrder kSrcOrder = SrcMapType::kOrder; static const int kZeroPointInputValue = - ZeroPointInputValue<KernelScalar>::kValue; + ZeroPointInputValue<KernelInputScalar, KernelScalar>::kValue; PackingRegisterBlockBase() : complete_src_(nullptr, 0, 0, 0) {} @@ -233,7 +238,7 @@ class PackingRegisterBlockBase { std::uint8_t buf_[kKernelWidth * kRegisterSize]; public: - // Selects a block if in-place source data that's already a complete block + // Selects a block if in-place source data that's already a complete block. 
void UseCompleteSrcInPlace(const SrcMapType& src) { complete_src_ = src; } // Copies an incomplete block of source data into a local temporary // complete block by zero-extending it. @@ -249,7 +254,10 @@ class PackingRegisterBlockBase { memcpy(buf_ + d * kKernelWidth, src.data(0, d), src.width()); } } - complete_src_ = SrcMapType(buf_, kKernelWidth, kRegisterSize); + + // Since the KernelInputScalar type may not be uint8, we need to cast buf_. + complete_src_ = SrcMapType(reinterpret_cast<KernelInputScalar*>(buf_), + kKernelWidth, kRegisterSize); } // Packs a complete block into the destination. This is the most // critical part and the part that we most typically want to @@ -340,7 +348,7 @@ class PackSideBlockImpl { } } - // Prefetches the data that will be read by PackL1 + // Prefetches the data that will be read by PackL1. void PrefetchL1(int start_width, int width, int start_depth, int depth) { if (SrcMapType::kOrder == SideMapOrder::WidthMajor) { for (int d = 0; d < depth; d += kDefaultCacheLineSize) { @@ -394,7 +402,7 @@ class PackSideBlockImpl { const SrcMapType& src_map_; }; -// Packs a block of the input LHS matrix, into a PackedSideBlock +// Packs a block of the input LHS matrix, into a PackedSideBlock. template <typename PackedSideBlock, typename MatrixMapType> void PackLhs(PackedSideBlock* dst, const MatrixMapType& src) { ScopedProfilingLabel label("pack LHS"); @@ -409,7 +417,7 @@ void PackLhs(PackedSideBlock* dst, const MatrixMapType& src) { impl.PackL2(); } -// Packs a block of the input RHS matrix, into a PackedSideBlock +// Packs a block of the input RHS matrix, into a PackedSideBlock. template <typename PackedSideBlock, typename MatrixMapType> void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) { ScopedProfilingLabel label("pack RHS"); @@ -430,6 +438,8 @@ void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) { #include "pack_neon.h" #elif defined(GEMMLOWP_SSE4) #include "pack_sse.h" +#elif defined(GEMMLOWP_AVX2) +#include "pack_avx.h" #elif defined(GEMMLOWP_MSA) #include "pack_msa.h" #endif diff --git a/internal/pack_avx.h b/internal/pack_avx.h new file mode 100644 index 0000000..1ef5ce1 --- /dev/null +++ b/internal/pack_avx.h @@ -0,0 +1,282 @@ +// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// pack_avx.h: optimized AVX specializations of the templates in pack.h. 
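Alongside rearranging the source block into the kernel's cell order, the packing code that follows also maintains sums_of_each_slice, accumulating the sum over the depth dimension for every width-slice (it does this with _mm256_madd_epi16 against a vector of ones). A scalar sketch of that bookkeeping, assuming a WidthMajor source map; the function name and parameters are illustrative only:

    #include <cstdint>

    // For each width-slice w, add the sum of its `depth` source bytes to sums[w].
    inline void AccumulateSliceSums(const std::uint8_t* src, int width, int depth,
                                    int width_stride, std::int32_t* sums) {
      for (int w = 0; w < width; ++w) {
        for (int d = 0; d < depth; ++d) {
          sums[w] += src[w * width_stride + d];
        }
      }
    }
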
+ +#ifndef GEMMLOWP_INTERNAL_PACK_AVX_H_ +#define GEMMLOWP_INTERNAL_PACK_AVX_H_ + +#include <immintrin.h> +#include "pack.h" + +namespace gemmlowp { + +// TODO: Add DepthMajorUint8SideMap + +typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor> + WidthMajorUint8SideMap; + +template <int Cells> +using WidthMajorSideFormatNCells4x2 = + KernelSideFormat<CellFormat<8, 2, CellOrder::WidthMajor>, Cells>; + +template <int Cells> +class PackingRegisterBlock< + WidthMajorUint8SideMap, + PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> + : public PackingRegisterBlockBase< + WidthMajorUint8SideMap, + PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> { + public: + typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + + void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) { + std::uint8_t *dst_ptr = dst->current_data(); + const int width_stride = this->complete_src_.width_stride(); + int depth_step = 16; + + __m256i one = _mm256_set1_epi16(1); + for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; + cell_start_depth += depth_step) { + for (int cell_start_width = 0; cell_start_width < kKernelWidth; + cell_start_width += kCellWidth) { + std::int32_t *cell_sums_of_each_slice_ptr = + dst->sums_of_each_slice() + start_width + cell_start_width; + const std::uint8_t *src_data = + this->complete_src_.data(cell_start_width, cell_start_depth); + + __m128i xmm1 = + _mm_loadu_si128(reinterpret_cast<const __m128i *>(&src_data[0])); + __m128i xmm2 = _mm_loadu_si128( + reinterpret_cast<const __m128i *>(&src_data[1 * width_stride])); + __m128i xmm3 = _mm_loadu_si128( + reinterpret_cast<const __m128i *>(&src_data[2 * width_stride])); + __m128i xmm4 = _mm_loadu_si128( + reinterpret_cast<const __m128i *>(&src_data[3 * width_stride])); + __m128i xmm5 = _mm_loadu_si128( + reinterpret_cast<const __m128i *>(&src_data[4 * width_stride])); + __m128i xmm6 = _mm_loadu_si128( + reinterpret_cast<const __m128i *>(&src_data[5 * width_stride])); + __m128i xmm7 = _mm_loadu_si128( + reinterpret_cast<const __m128i *>(&src_data[6 * width_stride])); + __m128i xmm8 = _mm_loadu_si128( + reinterpret_cast<const __m128i *>(&src_data[7 * width_stride])); + + __m256i ymm1 = _mm256_set_m128i(xmm5, xmm1); + __m256i ymm2 = _mm256_set_m128i(xmm6, xmm2); + __m256i ymm3 = _mm256_set_m128i(xmm7, xmm3); + __m256i ymm4 = _mm256_set_m128i(xmm8, xmm4); + + __m256i ymm5 = _mm256_unpacklo_epi16(ymm1, ymm2); + __m256i ymm6 = _mm256_unpacklo_epi16(ymm3, ymm4); + + __m256i ymm9 = _mm256_unpackhi_epi16(ymm1, ymm2); + __m256i ymm10 = _mm256_unpackhi_epi16(ymm3, ymm4); + + __m256i ymm7 = _mm256_unpacklo_epi32(ymm5, ymm6); + __m256i ymm8 = _mm256_unpackhi_epi32(ymm5, ymm6); + + __m256i ymm13 = _mm256_unpacklo_epi32(ymm9, ymm10); + __m256i ymm14 = _mm256_unpackhi_epi32(ymm9, ymm10); + + __m256i ymm11 = _mm256_permute4x64_epi64(ymm7, 0xd8); + __m256i ymm12 = _mm256_permute4x64_epi64(ymm8, 0xd8); + + __m256i ymm15 = _mm256_permute4x64_epi64(ymm13, 0xd8); + __m256i ymm16 = _mm256_permute4x64_epi64(ymm14, 0xd8); + + __m128i xmm9 = _mm256_castsi256_si128(ymm11); + __m128i xmm10 = _mm256_castsi256_si128(ymm12); + __m128i xmm11 = _mm256_extracti128_si256(ymm11, 1); + __m128i xmm12 = 
_mm256_extracti128_si256(ymm12, 1); + + xmm1 = _mm256_castsi256_si128(ymm15); + xmm2 = _mm256_castsi256_si128(ymm16); + xmm3 = _mm256_extracti128_si256(ymm15, 1); + xmm4 = _mm256_extracti128_si256(ymm16, 1); + + _mm_storeu_si128(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9); + _mm_storeu_si128( + reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm11); + _mm_storeu_si128( + reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]), + xmm10); + _mm_storeu_si128( + reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]), + xmm12); + _mm_storeu_si128( + reinterpret_cast<__m128i *>(&dst_ptr[4 * kCellSize * kCells]), + xmm1); + _mm_storeu_si128( + reinterpret_cast<__m128i *>(&dst_ptr[5 * kCellSize * kCells]), + xmm3); + + _mm_storeu_si128( + reinterpret_cast<__m128i *>(&dst_ptr[6 * kCellSize * kCells]), + xmm2); + _mm_storeu_si128( + reinterpret_cast<__m128i *>(&dst_ptr[7 * kCellSize * kCells]), + xmm4); + + ymm6 = _mm256_cvtepu8_epi16(xmm9); + ymm7 = _mm256_madd_epi16(ymm6, one); + __m256i sums_of_each_slice_xmm = _mm256_loadu_si256( + reinterpret_cast<const __m256i *>(&cell_sums_of_each_slice_ptr[0])); + sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); + + ymm6 = _mm256_cvtepu8_epi16(xmm11); + ymm7 = _mm256_madd_epi16(ymm6, one); + sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); + + ymm6 = _mm256_cvtepu8_epi16(xmm10); + ymm7 = _mm256_madd_epi16(ymm6, one); + sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); + + ymm6 = _mm256_cvtepu8_epi16(xmm12); + ymm7 = _mm256_madd_epi16(ymm6, one); + sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); + + ymm6 = _mm256_cvtepu8_epi16(xmm1); + ymm7 = _mm256_madd_epi16(ymm6, one); + sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); + + ymm6 = _mm256_cvtepu8_epi16(xmm3); + ymm7 = _mm256_madd_epi16(ymm6, one); + sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); + + ymm6 = _mm256_cvtepu8_epi16(xmm2); + ymm7 = _mm256_madd_epi16(ymm6, one); + sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); + + ymm6 = _mm256_cvtepu8_epi16(xmm4); + ymm7 = _mm256_madd_epi16(ymm6, one); + sums_of_each_slice_xmm = _mm256_add_epi32(sums_of_each_slice_xmm, ymm7); + + _mm256_storeu_si256( + reinterpret_cast<__m256i *>(&cell_sums_of_each_slice_ptr[0]), + sums_of_each_slice_xmm); + dst_ptr += kCellSize; + } + dst_ptr += 7 * kCellSize * kCells; + } + dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); + } +}; + +// Pack format for 4x2 rhs format +template <int Cells> +using RhsWidthMajorSideFormatNCells4x2 = + KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>; + +template <int Cells> +class PackingRegisterBlock< + WidthMajorUint8SideMap, + PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> + : public PackingRegisterBlockBase< + WidthMajorUint8SideMap, + PackedSideBlock<RhsWidthMajorSideFormatNCells4x2<Cells>>> { + public: + typedef RhsWidthMajorSideFormatNCells4x2<Cells> KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + + void Pack(PackedSideBlock<KernelSideFormat> *dst, int start_width) { + std::uint8_t *dst_ptr = dst->current_data(); + const int width_stride = 
this->complete_src_.width_stride(); + int depth_step = 8; + + __m128i one = _mm_set1_epi16(1); + for (int cell_start_depth = 0; cell_start_depth < kRegisterSize; + cell_start_depth += depth_step) { + for (int cell_start_width = 0; cell_start_width < kKernelWidth; + cell_start_width += kCellWidth) { + std::int32_t *cell_sums_of_each_slice_ptr = + dst->sums_of_each_slice() + start_width + cell_start_width; + const std::uint8_t *src_data = + this->complete_src_.data(cell_start_width, cell_start_depth); + + __m128i xmm1 = + _mm_loadl_epi64(reinterpret_cast<const __m128i *>(&src_data[0])); + __m128i xmm2 = _mm_loadl_epi64( + reinterpret_cast<const __m128i *>(&src_data[1 * width_stride])); + __m128i xmm3 = _mm_loadl_epi64( + reinterpret_cast<const __m128i *>(&src_data[2 * width_stride])); + __m128i xmm4 = _mm_loadl_epi64( + reinterpret_cast<const __m128i *>(&src_data[3 * width_stride])); + + __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2); + __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31); + + __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4); + __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80); + + __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc); + __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc); + + _mm_storel_epi64(reinterpret_cast<__m128i *>(&dst_ptr[0]), xmm9); + _mm_storel_epi64( + reinterpret_cast<__m128i *>(&dst_ptr[kCellSize * kCells]), xmm10); + + __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee); + __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee); + + _mm_storel_epi64( + reinterpret_cast<__m128i *>(&dst_ptr[2 * kCellSize * kCells]), + xmm11); + _mm_storel_epi64( + reinterpret_cast<__m128i *>(&dst_ptr[3 * kCellSize * kCells]), + xmm12); + + xmm1 = _mm_cvtepu8_epi16(xmm9); + xmm2 = _mm_madd_epi16(xmm1, one); + __m128i sums_of_each_slice_xmm = _mm_loadu_si128( + reinterpret_cast<const __m128i *>(&cell_sums_of_each_slice_ptr[0])); + sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); + + xmm1 = _mm_cvtepu8_epi16(xmm10); + xmm2 = _mm_madd_epi16(xmm1, one); + sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); + + xmm1 = _mm_cvtepu8_epi16(xmm11); + xmm2 = _mm_madd_epi16(xmm1, one); + sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); + + xmm1 = _mm_cvtepu8_epi16(xmm12); + xmm2 = _mm_madd_epi16(xmm1, one); + sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2); + + _mm_storeu_si128( + reinterpret_cast<__m128i *>(&cell_sums_of_each_slice_ptr[0]), + sums_of_each_slice_xmm); + dst_ptr += kCellSize; + } + dst_ptr += 3 * kCellSize * kCells; + } + dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth); + } +}; + +} // namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_PACK_AVX_H_ diff --git a/internal/pack_msa.h b/internal/pack_msa.h index fba8a0f..4072229 100644 --- a/internal/pack_msa.h +++ b/internal/pack_msa.h @@ -348,6 +348,84 @@ class PackingRegisterBlock< } }; +template <int Width> +using Int8FastKernelFormat = + KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>; + +template <int Width> +class PackingRegisterBlock<WidthMajorUint8SideMap, + PackedSideBlock<Int8FastKernelFormat<Width>>> + : public PackingRegisterBlockBase< + WidthMajorUint8SideMap, + PackedSideBlock<Int8FastKernelFormat<Width>>> { + public: + static_assert(Width == 2 || Width == 4, ""); + typedef Int8FastKernelFormat<Width> KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = 
CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + + void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) { + std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width; + std::uint8_t* dst_ptr = dst->current_data(); + const std::uint8_t* const src_ptr = this->complete_src_.data(); + const int stride = this->complete_src_.stride(); + // Load source WidthMajor data. + v16i8 src_lines[Width]; + for (int i = 0; i < Width; i++) { + src_lines[i] = __builtin_msa_ld_b( + const_cast<std::uint8_t*>(src_ptr + i * stride), 0); + } + for (int i = 0; i < Width; i++) { + // Subtract 128 by inverting bit 7. + src_lines[i] = reinterpret_cast<v16i8>( + __builtin_msa_bnegi_b(reinterpret_cast<v16u8>(src_lines[i]), 7)); + } + for (int i = 0; i < Width; i++) { + __builtin_msa_st_b(src_lines[i], dst_ptr + 16 * i, 0); + } + v8i16 sums2[Width]; + for (int i = 0; i < Width; i++) { + sums2[i] = __builtin_msa_hadd_s_h(src_lines[i], src_lines[i]); + } + v4i32 sums4_wide[Width]; + for (int i = 0; i < Width; i++) { + sums4_wide[i] = __builtin_msa_hadd_s_w(sums2[i], sums2[i]); + } + v8i16 sums4[Width / 2]; + for (int i = 0; i < Width / 2; i++) { + sums4[i] = __builtin_msa_pckev_h( + reinterpret_cast<v8i16>(sums4_wide[2 * i + 1]), + reinterpret_cast<v8i16>(sums4_wide[2 * i])); + } + v4i32 sums8_wide[Width / 2]; + for (int i = 0; i < Width / 2; i++) { + sums8_wide[i] = __builtin_msa_hadd_s_w(sums4[i], sums4[i]); + } + if (Width == 4) { + v4i32 sum = __builtin_msa_ld_w(const_cast<std::int32_t*>(sums_ptr), 0); + v8i16 sums8 = __builtin_msa_pckev_h( + reinterpret_cast<v8i16>(sums8_wide[1]), + reinterpret_cast<v8i16>(sums8_wide[0])); + v4i32 sums16 = __builtin_msa_hadd_s_w(sums8, sums8); + sum = __builtin_msa_addv_w(sum, sums16); + __builtin_msa_st_w(sum, sums_ptr, 0); + } else { + assert(Width == 2); + std::int32_t sum[2] = { sums_ptr[0], sums_ptr[1] }; + v2i64 sums16 = __builtin_msa_hadd_s_d(sums8_wide[0], sums8_wide[0]); + sum[0] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 0); + sum[1] += __builtin_msa_copy_s_w(reinterpret_cast<v4i32>(sums16), 2); + sums_ptr[0] = sum[0]; + sums_ptr[1] = sum[1]; + } + dst->seek_forward_n_cells(1); + } +}; + } // namespace gemmlowp #endif // GEMMLOWP_INTERNAL_PACK_MSA_H_ diff --git a/internal/pack_neon.h b/internal/pack_neon.h index 2b08464..f113d9e 100644 --- a/internal/pack_neon.h +++ b/internal/pack_neon.h @@ -26,6 +26,9 @@ namespace gemmlowp { typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor> WidthMajorUint8SideMap; +typedef SideMap<const std::int8_t, SideMapOrder::WidthMajor> + WidthMajorInt8SideMap; + template <int Cells> using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>; @@ -315,6 +318,67 @@ class PackingRegisterBlock<WidthMajorUint8SideMap, } }; +template <int Width> +using Int8InputsFastKernelFormat = + KernelSideFormatInt8Inputs<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>; + +// Same as above, but for int8 inputs, avoiding the uint8 -> int8 conversion. 
+template <int Width> +class PackingRegisterBlock<WidthMajorInt8SideMap, + PackedSideBlock<Int8InputsFastKernelFormat<Width>>> + : public PackingRegisterBlockBase< + WidthMajorInt8SideMap, + PackedSideBlock<Int8InputsFastKernelFormat<Width>>> { + public: + static_assert(Width == 2 || Width == 4, ""); + typedef Int8InputsFastKernelFormat<Width> KernelSideFormat; + typedef typename KernelSideFormat::Cell CellFormat; + static const int kCells = KernelSideFormat::kCells; + static const int kCellWidth = CellFormat::kWidth; + static const int kKernelWidth = CellFormat::kWidth * kCells; + static const int kCellDepth = CellFormat::kDepth; + static const int kCellSize = CellFormat::kSize; + + void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) { + std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width; + std::int8_t* dst_ptr = reinterpret_cast<std::int8_t*>(dst->current_data()); + const std::int8_t* const src_ptr = this->complete_src_.data(); + const int stride = this->complete_src_.stride(); + // Load source WidthMajor data + int8x16_t src_lines[Width]; + for (int i = 0; i < Width; i++) { + src_lines[i] = vld1q_s8(src_ptr + i * stride); + } + for (int i = 0; i < Width; i++) { + vst1q_s8(dst_ptr + 16 * i, src_lines[i]); + } + int16x8_t sums2[Width]; + for (int i = 0; i < Width; i++) { + const int8x8_t lo = vget_low_s8(src_lines[i]); + const int8x8_t hi = vget_high_s8(src_lines[i]); + sums2[i] = vaddl_s8(lo, hi); + } + int16x8_t sums4[Width / 2]; + for (int i = 0; i < Width / 2; i++) { + sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]); + } + if (Width == 4) { + int32x4_t sum = vld1q_s32(sums_ptr); + int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]); + sum = vpadalq_s16(sum, sums8); + vst1q_s32(sums_ptr, sum); + } else { + assert(Width == 2); + int32x2_t sum = vld1_s32(sums_ptr); + int16x4_t sums8 = + vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0])); + sum = vpadal_s16(sum, sums8); + vst1_s32(sums_ptr, sum); + } + dst->seek_forward_n_cells(1); + } +}; + } // namespace gemmlowp #endif // GEMMLOWP_INTERNAL_PACK_NEON_H_ diff --git a/internal/platform.h b/internal/platform.h index 1114767..ab71414 100644 --- a/internal/platform.h +++ b/internal/platform.h @@ -18,6 +18,7 @@ #define GEMMLOWP_INTERNAL_PLATFORM_H_ #ifdef _WIN32 +#include <malloc.h> #include <windows.h> #else #include <stdlib.h> @@ -71,8 +72,8 @@ inline int GetHardwareConcurrency(int max_threads) { inline double real_time_in_seconds() { __int64 wintime; GetSystemTimeAsFileTime((FILETIME *)&wintime); - wintime -= 116444736000000000i64; // 1jan1601 to 1jan1970 - return wintime / 10000000i64 + wintime % 10000000i64 * 100 * 1e-9; + wintime -= 116444736000000000LL; // 1jan1601 to 1jan1970 + return wintime / 10000000LL + wintime % 10000000LL * 100 * 1e-9; } #else diff --git a/internal/simd_wrappers.h b/internal/simd_wrappers.h index d9721c9..4e4cce8 100644 --- a/internal/simd_wrappers.h +++ b/internal/simd_wrappers.h @@ -105,10 +105,12 @@ struct FlipLhsRhs { using FlippedRhsType = RhsType; static const FlippedLhsType& FlippedLhs(const LhsType& lhs, const RhsType& rhs) { + (void)rhs; return lhs; } static const FlippedRhsType& FlippedRhs(const LhsType& lhs, const RhsType& rhs) { + (void)lhs; return rhs; } }; @@ -119,10 +121,12 @@ struct FlipLhsRhs<LhsType, RhsType, true> { using FlippedRhsType = LhsType; static const FlippedLhsType& FlippedLhs(const LhsType& lhs, const RhsType& rhs) { + (void)lhs; return rhs; } static const FlippedRhsType& FlippedRhs(const LhsType& lhs, const RhsType& rhs) { + (void)rhs; return 
lhs; } }; @@ -192,6 +196,153 @@ typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd( } template <typename Lhs, typename Rhs> +struct BroadcastShiftLeftImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? r : 0; + result.buf.reg[r + c * Rows] = + ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template <typename Lhs, typename Rhs> +typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft( + const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs<Lhs, Rhs>; + return BroadcastShiftLeftImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template <typename Lhs, typename Rhs> +struct BroadcastSaturatingRoundingDoublingHighMulImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? 
r : 0; + result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template <typename Lhs, typename Rhs> +typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type +BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs<Lhs, Rhs>; + return BroadcastSaturatingRoundingDoublingHighMulImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template <typename Lhs, typename Rhs> +struct BroadcastRoundingDivideByPOTImpl { + using ResultBlockType = + typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; + static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { + ResultBlockType result; + static constexpr int Rows = ResultBlockType::kRows; + static constexpr int Cols = ResultBlockType::kCols; + static constexpr int LhsRows = Lhs::kRows; + static constexpr int LhsCols = Lhs::kCols; + static constexpr int RhsRows = Rhs::kRows; + static constexpr int RhsCols = Rhs::kCols; + + static_assert(LhsRows == Rows || LhsRows == 1, ""); + static_assert(RhsRows == Rows || RhsRows == 1, ""); + static_assert(LhsCols == Cols || LhsCols == 1, ""); + static_assert(RhsCols == Cols || RhsCols == 1, ""); + static_assert(ResultBlockType::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Lhs::kRegisterLanes == 1, + "This path is only for scalar values"); + static_assert(Rhs::kRegisterLanes == 1, + "This path is only for scalar values"); + + for (int c = 0; c < Cols; c++) { + const int lhs_c = LhsCols == Cols ? c : 0; + const int rhs_c = RhsCols == Cols ? c : 0; + for (int r = 0; r < Rows; r++) { + const int lhs_r = LhsRows == Rows ? r : 0; + const int rhs_r = RhsRows == Rows ? 
r : 0; + result.buf.reg[r + c * Rows] = + RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows], + rhs.buf.reg[rhs_r + rhs_c * RhsRows]); + } + } + return result; + } +}; + +template <typename Lhs, typename Rhs> +typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type +BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) { + using Flip = FlipLhsRhs<Lhs, Rhs>; + return BroadcastRoundingDivideByPOTImpl< + typename Flip::FlippedLhsType, + typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), + Flip::FlippedRhs(lhs, rhs)); +} + +template <typename Lhs, typename Rhs> struct BroadcastMulImpl { using ResultBlockType = typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; @@ -494,12 +645,16 @@ template <int N> using RegBufferInt16 = RegisterBuffer<std::int16_t, N>; template <int N> using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>; +template <int N> +using RegBufferInt8 = RegisterBuffer<std::int8_t, N>; template <int R, int C> using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>; template <int R, int C> using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>; template <int R, int C> using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>; +template <int R, int C> +using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>; } // end namespace gemmlowp diff --git a/internal/simd_wrappers_common_neon_sse.h b/internal/simd_wrappers_common_neon_sse.h index 3830eb1..694bf99 100644 --- a/internal/simd_wrappers_common_neon_sse.h +++ b/internal/simd_wrappers_common_neon_sse.h @@ -350,6 +350,210 @@ struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> { } }; +// 4x1 := 4x1 + 1x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 1x4 := 1x4 + 1x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := 4x1 + 4x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 1>, + RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 1x4 := 1x4 + 1x4 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 4>, + RegBlockInt32<1, 4>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 4x4 := 4x4 + 1x4 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>, + RegBlockInt32<1, 4>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + 
result.buf.reg[1] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[2] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[3] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x4 := 4x4 + 4x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<4, 4>, + RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[0]); + result.buf.reg[2] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[0]); + return result; + } +}; + +// 8x1 := 8x1 + 1x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<8, 1> result; + const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], p); + } + return result; + } +}; + +// 8x1 := 8x1 + 8x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 1>, + RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 1> result; + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + +// 8x4 := 8x4 + 1x4 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>, + RegBlockInt32<1, 4>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[2] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[3] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[4] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[5] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[6] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); + result.buf.reg[7] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 8x4 := 8x4 + 8x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<8, 4>, + RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); + result.buf.reg[2] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = 
+ SaturatingRoundingDoublingHighMul(lhs.buf.reg[3], rhs.buf.reg[1]); + result.buf.reg[4] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[4], rhs.buf.reg[0]); + result.buf.reg[5] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[5], rhs.buf.reg[1]); + result.buf.reg[6] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[6], rhs.buf.reg[0]); + result.buf.reg[7] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[7], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x8 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>, + RegBlockInt32<1, 8>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 8>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = + SaturatingRoundingDoublingHighMul(lhs.buf.reg[1], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x1 +template <> +struct BroadcastSaturatingRoundingDoublingHighMulImpl<RegBlockInt32<1, 8>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + result.buf.reg[1] = SaturatingRoundingDoublingHighMul( + lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + // 4x1 := 4x1 * 1x1 template <> struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { diff --git a/internal/simd_wrappers_msa.h b/internal/simd_wrappers_msa.h index cf5e8e9..7de01ff 100644 --- a/internal/simd_wrappers_msa.h +++ b/internal/simd_wrappers_msa.h @@ -33,8 +33,7 @@ struct RegisterType<std::int32_t, ScalarCount> { template <int ScalarCount> struct RegisterType<std::int16_t, ScalarCount> { - using Type = - typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type; + using Type = typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type; }; template <int ScalarCount> @@ -69,13 +68,9 @@ inline Int16x8 LoadInt16x8(const Int16x8* src) { return __builtin_msa_ld_h(const_cast<Int16x8*>(src), 0); } -inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) { - __builtin_msa_st_h(value, dst, 0); -} +inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); } -inline void StoreInt16x8(Int16x8* dst, Int16x8 value) { - __builtin_msa_st_h(value, dst, 0); -} +inline void StoreInt16x8(Int16x8* dst, Int16x8 value) { __builtin_msa_st_h(value, dst, 0); } inline Uint8x16 LoadUint8x16(const std::uint8_t* src) { return __builtin_msa_ld_b(const_cast<std::uint8_t*>(src), 0); diff --git a/internal/simd_wrappers_neon.h b/internal/simd_wrappers_neon.h index 2949173..6871055 100644 --- a/internal/simd_wrappers_neon.h +++ b/internal/simd_wrappers_neon.h @@ -25,6 +25,7 @@ using Int32x4 = int32x4_t; using Int16x4 = int16x4_t; using Int16x8 = int16x8_t; using Uint8x8 = uint8x8_t; +using Int8x8 = int8x8_t; template <int ScalarCount> struct RegisterType<std::int32_t, ScalarCount> { @@ -48,6 +49,14 @@ struct RegisterType<std::uint8_t, ScalarCount> { std::uint8_t>::type>::type; }; +template <int ScalarCount> +struct RegisterType<std::int8_t, ScalarCount> { + using Type = typename std::conditional< + ScalarCount >= 8, Int8x8, + typename std::conditional<ScalarCount >= 4, std::int32_t, + std::int8_t>::type>::type; +}; + inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); } inline Int16x4 
LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); } inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); } @@ -92,6 +101,10 @@ inline Int32x4 Min(Int32x4 a, Int32x4 b) { return vminq_s32(a, b); } inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); } +inline Int32x4 Max(Int32x4 a, std::int32_t b) { + return vmaxq_s32(a, vdupq_n_s32(b)); +} + inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) { return vqrdmulhq_n_s32(a, b); } @@ -164,6 +177,17 @@ struct LoadContiguousImpl<RegBlockUint8<8, 8>> { }; template <> +struct LoadContiguousImpl<RegBlockInt8<8, 8>> { + static RegBlockInt8<8, 8> Run(const std::int8_t* src) { + RegBlockInt8<8, 8> result; + for (int i = 0; i < 8; i++) { + result.buf.reg[i] = vld1_s8(src + 8 * i); + } + return result; + } +}; + +template <> struct LoadContiguousImpl<RegBlockInt32<8, 8>> { static RegBlockInt32<8, 8> Run(const std::int32_t* src) { RegBlockInt32<8, 8> result; @@ -174,6 +198,352 @@ struct LoadContiguousImpl<RegBlockInt32<8, 8>> { } }; +// 4x1 := 4x1 + 1x1 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 1x4 := 1x4 + 1x1 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := 4x1 + 4x1 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 1x4 := 1x4 + 1x4 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 4x4 := 4x4 + 1x4 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x4 := 4x4 + 4x1 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[0]); + result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[0]); + return result; + } +}; + +// 8x1 := 8x1 + 1x1 +template <> +struct 
BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<8, 1> result; + const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], p); + } + return result; + } +}; + +// 8x1 := 8x1 + 8x1 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 1> result; + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + +// 8x4 := 8x4 + 1x4 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); + result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 8x4 := 8x4 + 8x1 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]); + result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[1]); + result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], rhs.buf.reg[0]); + result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], rhs.buf.reg[1]); + result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], rhs.buf.reg[0]); + result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x8 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 8>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x1 +template <> +struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := 4x1 + 1x1 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = + RoundingDivideByPOT(lhs.buf.reg[0], 
Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 1x4 := 1x4 + 1x1 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = + RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x1 := 4x1 + 4x1 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>, + RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 1> result; + result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 1x4 := 1x4 + 1x4 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>, + RegBlockInt32<1, 4>> { + static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<1, 4> result; + result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]); + return result; + } +}; + +// 4x4 := 4x4 + 1x4 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>, + RegBlockInt32<1, 4>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = + RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = + RoundingDivideByPOT(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[2] = + RoundingDivideByPOT(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[3] = + RoundingDivideByPOT(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 4x4 := 4x4 + 4x1 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>, + RegBlockInt32<4, 1>> { + static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, + const RegBlockInt32<4, 1>& rhs) { + RegBlockInt32<4, 4> result; + result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[0]); + result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[0]); + return result; + } +}; + +// 8x1 := 8x1 + 1x1 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<8, 1> result; + const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], p); + } + return result; + } +}; + +// 8x1 := 8x1 + 8x1 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>, + RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 1> result; + for (int i = 0; i < 2; i++) { + result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], rhs.buf.reg[i]); + } + return result; + } +}; + +// 8x4 := 8x4 + 1x4 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>, + RegBlockInt32<1, 4>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<1, 4>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = + RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[1] = + 
RoundingDivideByPOT(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); + result.buf.reg[2] = + RoundingDivideByPOT(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[3] = + RoundingDivideByPOT(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); + result.buf.reg[4] = + RoundingDivideByPOT(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[5] = + RoundingDivideByPOT(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); + result.buf.reg[6] = + RoundingDivideByPOT(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); + result.buf.reg[7] = + RoundingDivideByPOT(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); + return result; + } +}; + +// 8x4 := 8x4 + 8x1 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>, + RegBlockInt32<8, 1>> { + static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, + const RegBlockInt32<8, 1>& rhs) { + RegBlockInt32<8, 4> result; + result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]); + result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]); + result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[1]); + result.buf.reg[4] = RoundingDivideByPOT(lhs.buf.reg[4], rhs.buf.reg[0]); + result.buf.reg[5] = RoundingDivideByPOT(lhs.buf.reg[5], rhs.buf.reg[1]); + result.buf.reg[6] = RoundingDivideByPOT(lhs.buf.reg[6], rhs.buf.reg[0]); + result.buf.reg[7] = RoundingDivideByPOT(lhs.buf.reg[7], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x8 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>, + RegBlockInt32<1, 8>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 8>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]); + result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]); + return result; + } +}; + +// 1x8 := 1x8 + 1x1 +template <> +struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>, + RegBlockInt32<1, 1>> { + static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, + const RegBlockInt32<1, 1>& rhs) { + RegBlockInt32<1, 8> result; + result.buf.reg[0] = + RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); + result.buf.reg[1] = + RoundingDivideByPOT(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); + return result; + } +}; + } // end namespace gemmlowp #include "simd_wrappers_common_neon_sse.h" diff --git a/internal/unpack.h b/internal/unpack.h index 33aee13..021f4aa 100644 --- a/internal/unpack.h +++ b/internal/unpack.h @@ -98,12 +98,14 @@ void UnpackResultBlock(const SrcMapType& src, const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, int depth, int src_row, int src_col, int src_global_row, int src_global_col, int dst_row, int dst_col) { + using KernelLhsInputScalar = typename KernelFormat::Lhs::InputScalar; using KernelLhsScalar = typename KernelFormat::Lhs::Scalar; + using KernelRhsInputScalar = typename KernelFormat::Rhs::InputScalar; using KernelRhsScalar = typename KernelFormat::Rhs::Scalar; static constexpr int KernelLhsZeroPointInput = - ZeroPointInputValue<KernelLhsScalar>::kValue; + ZeroPointInputValue<KernelLhsInputScalar, KernelLhsScalar>::kValue; static constexpr int KernelRhsZeroPointInput = - ZeroPointInputValue<KernelRhsScalar>::kValue; + ZeroPointInputValue<KernelRhsInputScalar, KernelRhsScalar>::kValue; auto acc = Load<RegisterBlockType>(src, src_row, src_col); const auto& lhs_sums_of_each_slice_block = 
LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row); diff --git a/meta/multi_thread_common.h b/meta/multi_thread_common.h index 0b35759..b39c3f2 100644 --- a/meta/multi_thread_common.h +++ b/meta/multi_thread_common.h @@ -22,9 +22,15 @@ namespace meta { inline int ResolveMaxThreads(int max_threads) { if (max_threads == 0) { +#ifdef _WIN32 + SYSTEM_INFO sysinfo; + GetSystemInfo(&sysinfo); + return sysinfo.dwNumberOfProcessors; +#else static const int hardware_threads_count = static_cast<int>(sysconf(_SC_NPROCESSORS_CONF)); return hardware_threads_count; +#endif } return max_threads; } diff --git a/profiling/instrumentation.h b/profiling/instrumentation.h index 437fe54..c1f852e 100644 --- a/profiling/instrumentation.h +++ b/profiling/instrumentation.h @@ -108,13 +108,14 @@ struct ScopedLock { // contains pointers to literal strings that were manually entered // in the instrumented code (see ScopedProfilingLabel). struct ProfilingStack { - static const std::size_t kMaxSize = 14; + static const std::size_t kMaxSize = 30; typedef const char* LabelsArrayType[kMaxSize]; LabelsArrayType labels; std::size_t size; Mutex* lock; ProfilingStack() { memset(this, 0, sizeof(ProfilingStack)); } + ~ProfilingStack() { delete lock; } void Push(const char* label) { ScopedLock sl(lock); @@ -171,8 +172,6 @@ struct ThreadInfo { ScopedLock sl(GlobalMutexes::Profiler()); ThreadInfo* self = static_cast<ThreadInfo*>(ptr); ThreadsUnderProfiling().erase(self); - pthread_key_delete(self->key); - delete self->stack.lock; } }; @@ -185,7 +184,11 @@ inline ThreadInfo& ThreadLocalThreadInfo() { } }; - static int key_result = pthread_key_create(&key, DeleteThreadInfo); + // key_result is unused. The purpose of this 'static' local object is + // to have its initializer (the pthread_key_create call) performed exactly + // once, in a way that is guaranteed (since C++11) to be reentrant. 
+  static const int key_result = pthread_key_create(&key, DeleteThreadInfo);
+  (void)key_result;

  ThreadInfo* threadInfo = static_cast<ThreadInfo*>(pthread_getspecific(key));
  if (!threadInfo) {
diff --git a/profiling/pthread_everywhere.h b/profiling/pthread_everywhere.h
index df17c6f..2569bbc 100644
--- a/profiling/pthread_everywhere.h
+++ b/profiling/pthread_everywhere.h
@@ -60,6 +60,9 @@ inline void pthread_cond_init(pthread_cond_t *cond, std::nullptr_t) {
  *cond = new std::condition_variable;
}
inline void pthread_cond_signal(pthread_cond_t *cond) { (*cond)->notify_one(); }
+inline void pthread_cond_broadcast(pthread_cond_t *cond) {
+  (*cond)->notify_all();
+}
inline void pthread_cond_wait(pthread_cond_t *cond, pthread_mutex_t *mutex) {
  std::unique_lock<std::mutex> lock(**mutex, std::adopt_lock);
  (*cond)->wait(lock);
diff --git a/public/bit_depth.h b/public/bit_depth.h
index 6cb4ecf..412944e 100644
--- a/public/bit_depth.h
+++ b/public/bit_depth.h
@@ -24,14 +24,15 @@ template <int tMinValue, int tMaxValue>
struct OperandRange {
  static const int kMinValue = tMinValue;
  static const int kMaxValue = tMaxValue;
-  static_assert(0 <= kMinValue, "");
  static_assert(kMinValue < kMaxValue, "");
-  static_assert(kMaxValue <= 255, "");
};

using Uint8Range = OperandRange<0, 255>;
using Uint8RangeExcludingZero = OperandRange<1, 255>;

+using Int8Range = OperandRange<-128, 127>;
+using Int8RangeExcludingLow = OperandRange<-127, 127>;
+
template <typename tLhsRange, typename tRhsRange>
struct BitDepthParams {
  using LhsRange = tLhsRange;
@@ -47,6 +48,11 @@ using DefaultL8R8BitDepthParams = BitDepthParams<Uint8Range, Uint8Range>;
using L8R8WithLhsNonzeroBitDepthParams =
    BitDepthParams<Uint8RangeExcludingZero, Uint8Range>;

+// Signed variant: this allows using faster kernels that use signed arithmetic; see
+// NEON_64bit_GEMM_Int8Operands_Int32Accumulators_AccumTwoWithin16Bits
+using SignedL8R8WithLhsNonzeroBitDepthParams =
+    BitDepthParams<Int8RangeExcludingLow, Int8Range>;
+
// Deprecated: when gemmlowp used to allow requantizing 8bit
// inputs to less-than-8-bit depths, the public setting allowing
// that was DefaultL7R5BitDepthParams. That requantization
diff --git a/public/map.h b/public/map.h
index 3073e05..fe6bc5c 100644
--- a/public/map.h
+++ b/public/map.h
@@ -131,6 +131,7 @@ class VectorDup {
    assert(start >= 0);
    assert(start + len <= size_);
+    (void)start;
    return VectorDup(data_, len);
  }
};
diff --git a/public/output_stages.h b/public/output_stages.h
index 1d5fca4..797b662 100644
--- a/public/output_stages.h
+++ b/public/output_stages.h
@@ -138,12 +138,44 @@ struct OutputStageScaleInt32ByFixedPointAndExponent {
  std::int32_t result_offset_after_shift;
};

+// Variant of OutputStageQuantizeDownInt32ByFixedPoint where the 'shift'
+// is not necessarily just a right shift, so we can represent multipliers
+// greater than 1. This takes a result_exponent parameter; when it's
+// <= 0, this is equivalent to OutputStageQuantizeDownInt32ByFixedPoint
+// with result_shift = -result_exponent.
+// In the general case, this consists in first left-shifting by
+// std::max(result_exponent, 0), before doing the same as
+// OutputStageQuantizeDownInt32ByFixedPoint with
+// result_shift = std::max(-result_exponent, 0).
+//
+// The difference from OutputStageScaleInt32ByFixedPointAndExponent is that
+// each row or column of the output (depending on tShape) has its own
+// result_fixedpoint_multiplier and result_exponent numbers.
+template <VectorShape tShape>
+struct OutputStageScaleInt32ByFixedPointAndExponentPC {
+  VectorMap<const std::int32_t, tShape> result_fixedpoint_multiplier;
+  VectorMap<const std::int32_t, tShape> result_exponent;
+  std::int32_t result_offset_after_shift;
+};
+
// This output stage takes int32 values that are expected to be already
// on the final uint8 scale, but not necessarily in the [0..255] range.
// It clamps them to the [0..255] range and returns them casted to uint8.
struct OutputStageSaturatingCastToUint8 {};

// This output stage takes int32 values that are expected to be already
+// on the final int8 scale, but not necessarily in the [-128..127] range.
+// It clamps them to the [-128..127] range and returns them casted to int8.
+struct OutputStageSaturatingCastToInt8 {};
+
+// This output stage takes int32 values that are expected to be already
+// in the [0..255] range and returns them casted to uint8.
+// This stage can save time if used instead of the
+// OutputStageSaturatingCastToUint8 stage immediately after the
+// OutputStageClamp stage.
+struct OutputStageTruncatingCastToUint8 {};
+
+// This output stage takes int32 values that are expected to be already
// on the final int16 scale, but not necessarily in the [-32768..32767] range.
// It clamps them to the [-32768..32767] range and returns them casted to int16.
struct OutputStageSaturatingCastToInt16 {};
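
For readers who do not speak NEON intrinsics, here is a plain scalar sketch of what the new Int8InputsFastKernelFormat packing specialization in internal/pack_neon.h (near the top of this change) appears to compute: it copies a Width x 16 WidthMajor block of int8 source data into the packed buffer unchanged, and accumulates the sum of each of the Width rows into the per-slice sums used later for the offset terms. The function name PackInt8BlockScalar and its parameters are hypothetical stand-ins for illustration, not gemmlowp API.

#include <cstdint>

// Hypothetical scalar equivalent of the NEON Pack() above: Width rows of
// 16 int8 values each (one packed cell of depth 16 per call).
template <int Width>
void PackInt8BlockScalar(const std::int8_t* src, int src_stride,
                         std::int8_t* packed_dst,
                         std::int32_t* sums_of_each_slice) {
  static_assert(Width == 2 || Width == 4, "same widths as the NEON path");
  for (int row = 0; row < Width; row++) {
    std::int32_t row_sum = 0;
    for (int d = 0; d < 16; d++) {
      const std::int8_t v = src[row * src_stride + d];
      packed_dst[16 * row + d] = v;  // data is stored WidthMajor, unchanged
      row_sum += v;                  // per-row sum of the depth-16 slice
    }
    sums_of_each_slice[row] += row_sum;  // accumulated across depth blocks
  }
}

The NEON version computes the same per-row totals with widening pairwise adds (vaddl_s8, vpaddq_s16, vpadalq_s16) so that the sums come out one lane per row.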
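
The new BroadcastShiftLeft, BroadcastSaturatingRoundingDoublingHighMul and BroadcastRoundingDivideByPOT paths all bottom out in two fixed-point primitives. As a reference for what those primitives do, here is a scalar sketch of their commonly used semantics in gemmlowp-style fixed-point code: the first returns the high 32 bits of 2*a*b with rounding, saturating the single overflow case a == b == INT32_MIN; the second is a rounding arithmetic right shift. The *Ref names are illustrative; treat this as a sketch, not the library's exact implementation.

#include <cstdint>
#include <limits>

// Scalar sketch of SaturatingRoundingDoublingHighMul: high 32 bits of
// 2*a*b, rounded, saturating the lone overflow case INT32_MIN * INT32_MIN.
std::int32_t SaturatingRoundingDoublingHighMulRef(std::int32_t a,
                                                  std::int32_t b) {
  const bool overflow =
      a == b && a == std::numeric_limits<std::int32_t>::min();
  const std::int64_t ab_64 = static_cast<std::int64_t>(a) * b;
  const std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
  const std::int32_t ab_x2_high32 =
      static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}

// Scalar sketch of RoundingDivideByPOT: arithmetic right shift by 'exponent'
// with round-to-nearest (ties rounded away from zero).
std::int32_t RoundingDivideByPOTRef(std::int32_t x, int exponent) {
  const std::int32_t mask = static_cast<std::int32_t>((1ll << exponent) - 1);
  const std::int32_t remainder = x & mask;
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}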
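
The instrumentation.h change relies on C++11 "magic statics": the initializer of a function-local static runs exactly once even when several threads reach it concurrently, so a static const int whose initializer calls pthread_key_create can replace an explicit once-initialization mechanism. A minimal standalone sketch of that idiom follows; the TlsKey function and the lambda are illustrative, not gemmlowp code.

#include <pthread.h>

// One-time, thread-safe creation of a pthread TLS key using a C++11
// function-local static: the lambda (and thus pthread_key_create) is
// guaranteed to run exactly once.
pthread_key_t& TlsKey() {
  static pthread_key_t key = []() -> pthread_key_t {
    pthread_key_t k;
    pthread_key_create(&k, nullptr);  // executed exactly once
    return k;
  }();
  return key;
}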
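
Finally, the comment on OutputStageScaleInt32ByFixedPointAndExponentPC describes a small amount of per-coefficient arithmetic. Here is a hedged scalar sketch of that description, assuming the multiplier and exponent are taken per row (the per-column case is symmetric) and reusing the two reference helpers sketched above (declared here so the snippet is self-contained); the function name and the exact operation order are an illustration of the comment, not the library's implementation.

#include <cstdint>

// Reference helpers sketched earlier.
std::int32_t SaturatingRoundingDoublingHighMulRef(std::int32_t a,
                                                  std::int32_t b);
std::int32_t RoundingDivideByPOTRef(std::int32_t x, int exponent);

// Hypothetical scalar sketch of the per-channel scaling described above:
// left-shift by max(exponent, 0), fixed-point multiply, rounding right
// shift by max(-exponent, 0), then add the offset.
std::int32_t ScaleInt32ByFixedPointAndExponentRef(
    std::int32_t acc, std::int32_t fixedpoint_multiplier,
    std::int32_t exponent, std::int32_t offset_after_shift) {
  const int left_shift = exponent > 0 ? exponent : 0;
  const int right_shift = exponent < 0 ? -exponent : 0;
  const std::int32_t scaled = SaturatingRoundingDoublingHighMulRef(
      acc * (1 << left_shift), fixedpoint_multiplier);
  return RoundingDivideByPOTRef(scaled, right_shift) + offset_after_shift;
}

With tShape selecting rows or columns, each output channel can carry its own quantized scale, which is presumably the motivation: per-channel quantization of the result.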