aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMiao Wang <miaowang@google.com>2018-02-23 23:31:32 +0000
committerandroid-build-merger <android-build-merger@google.com>2018-02-23 23:31:32 +0000
commit0ed4f31d5ced2432473aa7063bc1e28d990ff3f2 (patch)
treea6ece8759b2fc774b39edea08417e08fa633a73c
parent97962621d25000e4eda770f4dd399a4378fd6b8b (diff)
parent1f4ec3258fe3b77841065990a20fe2047464688b (diff)
downloadgemmlowp-oreo-mr1-1.2-iot-release.tar.gz
Rebase gemmlowp to ecae4d1 am: 7d0d5a611e am: 9fa88931b4android-wear-8.0.0_r2android-o-mr1-iot-release-smart-display-r9android-o-mr1-iot-release-smart-display-r8android-o-mr1-iot-release-smart-display-r5android-o-mr1-iot-release-smart-display-r40.1Jandroid-o-mr1-iot-release-smart-display-r4android-o-mr1-iot-release-smart-display-r39android-o-mr1-iot-release-smart-display-r30android-o-mr1-iot-release-smart-display-r3android-o-mr1-iot-release-smart-display-r22android-o-mr1-iot-release-smart-display-r14android-o-mr1-iot-release-smart-clock-r6android-o-mr1-iot-release-smart-clock-r2android-o-mr1-iot-release-smart-clock-fsiandroid-o-mr1-iot-release-smart-clock-fcsandroid-o-mr1-iot-release-cube_r2android-o-mr1-iot-release-cube-fsiandroid-o-mr1-iot-release-cube-fcsandroid-o-mr1-iot-release-1.0.5android-o-mr1-iot-release-1.0.4android-o-mr1-iot-release-1.0.3android-n-iot-release-ihome-igv1android-9.0.0_r47android-9.0.0_r46android-9.0.0_r45android-9.0.0_r44android-9.0.0_r43android-9.0.0_r42android-9.0.0_r41android-9.0.0_r40android-9.0.0_r39android-9.0.0_r38android-9.0.0_r37android-9.0.0_r36android-9.0.0_r35android-9.0.0_r34android-9.0.0_r33android-9.0.0_r32android-9.0.0_r31android-9.0.0_r30android-9.0.0_r22android-9.0.0_r21android-9.0.0_r20android-9.0.0_r19android-9.0.0_r16android-9.0.0_r12android-9.0.0_r11pie-qpr3-s1-releasepie-qpr3-releasepie-qpr3-b-releasepie-qpr2-releasepie-qpr1-s3-releasepie-qpr1-s2-releasepie-qpr1-s1-releasepie-qpr1-releasepie-dr1-releasepie-dr1-devpie-devpie-b4s4-releasepie-b4s4-devoreo-mr1-1.2-iot-releasenougat-iot-releasemaster-cuttlefish-testing-release
am: 1f4ec3258f Change-Id: Icb9df1558e7d87c03080597ffbb5a6212817cba6
-rw-r--r--doc/quantization.md6
-rw-r--r--doc/quantization_example.cc4
-rw-r--r--fixedpoint/fixedpoint.h147
-rw-r--r--fixedpoint/fixedpoint_msa.h354
-rw-r--r--fixedpoint/fixedpoint_neon.h156
-rw-r--r--fixedpoint/fixedpoint_sse.h174
-rw-r--r--internal/common.h40
-rw-r--r--internal/kernel_default.h39
-rw-r--r--internal/kernel_msa.h339
-rw-r--r--internal/kernel_neon.h93
-rw-r--r--internal/multi_thread_gemm.h9
-rw-r--r--internal/output.h62
-rw-r--r--internal/output_msa.h622
-rw-r--r--internal/output_neon.h275
-rw-r--r--internal/output_sse.h186
-rw-r--r--internal/pack.h2
-rw-r--r--internal/pack_msa.h353
-rw-r--r--internal/pack_neon.h8
-rw-r--r--[-rwxr-xr-x]internal/platform.h30
-rw-r--r--internal/simd_wrappers.h6
-rw-r--r--internal/simd_wrappers_msa.h196
-rw-r--r--internal/simd_wrappers_neon.h31
-rw-r--r--internal/simd_wrappers_sse.h26
-rw-r--r--internal/single_thread_gemm.h7
-rw-r--r--meta/multi_thread_common.h9
-rw-r--r--profiling/instrumentation.h13
-rw-r--r--profiling/pthread_everywhere.h42
-rw-r--r--public/output_stages.h32
-rwxr-xr-xscripts/ci-test.sh2
-rw-r--r--standalone/neon-gemm-kernel-benchmark.cc1458
30 files changed, 4526 insertions, 195 deletions
diff --git a/doc/quantization.md b/doc/quantization.md
index 3e0df16..3a8f72b 100644
--- a/doc/quantization.md
+++ b/doc/quantization.md
@@ -301,7 +301,7 @@ the particular quantization paradigm that we detailed above in this document.
The specific output pipeline stage implementing the present quantization
paradigm, i.e. implementing the precise computation detailed in the previous
section (equation (5)), is
-`OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint`.
+`OutputStageQuantizeDownInt32ByFixedPoint`.
Please refer to the comment explaining it in
[public/output_stages.h](../public/output_stages.h).
@@ -313,7 +313,7 @@ The difference between the older legacy quantization paradigm described in
document boils down to the difference between the legacy output stage
implementing it, `OutputStageQuantizeDownInt32ToUint8Scale`, and the new output
stage implementing the new paradigm,
-`OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint`.
+`OutputStageQuantizeDownInt32ByFixedPoint`.
Please refer to the comments in
[public/output_stages.h](../public/output_stages.h) for details about these two
@@ -323,7 +323,7 @@ Issues with the old output stage `OutputStageQuantizeDownInt32ToUint8Scale` are:
1. The int32 accumulators (inputs to the output stage) undergo a plain int32
multiplication with a int32 multiplier, which may overflow. By contrast, in
- the newer `OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint`, this
+ the newer `OutputStageQuantizeDownInt32ByFixedPoint`, this
integer multiplication becomes a fixed-point multiplication and cannot
overflow.
diff --git a/doc/quantization_example.cc b/doc/quantization_example.cc
index 4368de2..d7b147d 100644
--- a/doc/quantization_example.cc
+++ b/doc/quantization_example.cc
@@ -201,7 +201,7 @@ std::ostream& operator<<(std::ostream& s,
//
// This is how to obtain the fixed-point multiplier and right shift
// parameters to pass to
-// OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint.
+// OutputStageQuantizeDownInt32ByFixedPoint.
//
// Note: all this code only needs to run offline to generate the quantized
// neural network workload, not at runtime on the
@@ -347,7 +347,7 @@ int main() {
<< "use quantized arithmetic.\n"
<< std::endl;
- gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
+ gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint
quantize_down_stage;
quantize_down_stage.result_offset_after_shift = result_offset;
quantize_down_stage.result_fixedpoint_multiplier = quantized_multiplier;
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index e21337f..d39341b 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -50,6 +50,12 @@ struct FixedPointRawTypeTraits<std::int32_t> {
static const int kLanes = 1;
};
+template <>
+struct FixedPointRawTypeTraits<std::int16_t> {
+ typedef std::int16_t ScalarRawType;
+ static const int kLanes = 1;
+};
+
// Returns a SIMD value duplicating a scalar value across all lanes.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
@@ -217,6 +223,50 @@ inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
return static_cast<std::int32_t>((sum + sign) / 2);
}
+template <>
+inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
+ std::int32_t a32 = a;
+ std::int32_t b32 = b;
+ std::int32_t sum = a32 + b32;
+ std::int32_t sign = sum >= 0 ? 1 : -1;
+ return static_cast<std::int16_t>((sum + sign) / 2);
+}
+
+template <typename IntegerType>
+IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
+ static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ return a;
+}
+
+// So far this is only needed for int16.
+template <>
+inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
+ std::int32_t a32 = a;
+ std::int32_t b32 = b;
+ std::int32_t sum = a32 + b32;
+ return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum)));
+}
+
+// Returns a+b, saturating if the integers are 16bit or narrower,
+// otherwise just a plain addition.
+template <typename IntegerType, bool Is16Bit>
+struct AddSaturatingIf16BitImpl {
+ static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
+};
+template <typename IntegerType>
+struct AddSaturatingIf16BitImpl<IntegerType, true> {
+ static IntegerType Run(IntegerType a, IntegerType b) {
+ return SaturatingAdd(a, b);
+ }
+};
+template <typename IntegerType>
+IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
+ using ScalarType =
+ typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
+ return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
+ b);
+}
+
// Returns the integer that represents the product of two fixed-point
// numbers, interpreting all integers as fixed-point values in the
// interval [-1, 1), rounding to the nearest value, and saturating
@@ -266,14 +316,23 @@ inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}
+template <>
+inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
+ std::int16_t b) {
+ bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min();
+ std::int32_t a_32(a);
+ std::int32_t b_32(b);
+ std::int32_t ab_32 = a_32 * b_32;
+ std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14));
+ std::int16_t ab_x2_high16 =
+ static_cast<std::int16_t>((ab_32 + nudge) / (1 << 15));
+ return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16;
+}
+
// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
template <typename IntegerType>
inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
- using ScalarIntegerType =
- typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
- static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
- "Currently only supporting int32 scalar and SIMD types");
assert(exponent >= 0);
assert(exponent <= 31);
const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
@@ -304,14 +363,14 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
static IntegerType eval(IntegerType x) {
using ScalarIntegerType =
typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
- static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
- "Currently only supporting int32 scalar and SIMD types");
const IntegerType min =
- Dup<IntegerType>(std::numeric_limits<std::int32_t>::min());
+ Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
const IntegerType max =
- Dup<IntegerType>(std::numeric_limits<std::int32_t>::max());
+ Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+ const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
- const std::int32_t threshold = ((1 << (31 - Exponent)) - 1);
+ const std::int32_t threshold =
+ ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
const IntegerType positive_mask =
MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
const IntegerType negative_mask =
@@ -425,15 +484,16 @@ class FixedPoint {
static FixedPoint Zero() { return FromScalarRaw(0); }
static FixedPoint One() {
- return FromScalarRaw(kIntegerBits == 0
- ? ScalarRawMax()
- : (ScalarRawType(1) << kFractionalBits));
+ return FromScalarRaw(
+ kIntegerBits == 0
+ ? ScalarRawMax()
+ : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
}
static FixedPoint FromDouble(double x) {
const double min_bound = static_cast<double>(ScalarRawMin());
const double max_bound = static_cast<double>(ScalarRawMax());
- return FromScalarRaw(static_cast<std::int32_t>(std::min(
+ return FromScalarRaw(static_cast<ScalarRawType>(std::min(
std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
min_bound),
max_bound)));
@@ -555,6 +615,22 @@ bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
return !(a == b);
}
+template <typename tRawType, int tIntegerBits>
+FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
+ FixedPoint<tRawType, tIntegerBits> a,
+ FixedPoint<tRawType, tIntegerBits> b) {
+ return FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingAdd(a.raw(), b.raw()));
+}
+
+template <typename tRawType, int tIntegerBits>
+FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
+ FixedPoint<tRawType, tIntegerBits> a,
+ FixedPoint<tRawType, tIntegerBits> b) {
+ return FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ AddSaturatingIf16Bit(a.raw(), b.raw()));
+}
+
// Conversion to floating-point.
template <typename tRawType, int tIntegerBits>
double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
@@ -579,23 +655,41 @@ FixedPoint<tRawType, tIntegerBitsDst> Rescale(
// initialized as real numbers, in a way that does not compile floating-point
// arithmetic in production code, yet still checks agreement with the
// floating-point expressions when asserts are enabled.
+//
+// The raw integer value provided is always a int32, encoding a 32-bit
+// fixed-point value, regardless of the actual Scalar type. This allows
+// writing generic code that applies just as well to the 32-bit and 16-bit
+// cases. In the 16-bit case, the raw integer value is internally
+// rounding-shifted by 16 bits to the right.
+template <typename FixedPointType>
+inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
+ std::int32_t int32_value) {
+ typedef typename FixedPointType::ScalarRawType ScalarRawType;
+ static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
+ return static_cast<ScalarRawType>(
+ RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
+}
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
template <typename FixedPointType>
-FixedPointType CheckedFixedPointConstant(
- typename FixedPointType::ScalarRawType raw_value, double double_value) {
- typedef typename FixedPointType::RawType RawType;
+FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
+ double double_value) {
const FixedPointType result = FixedPointType::FromScalarRaw(raw_value);
assert(result == FixedPointType::FromDouble(double_value));
return result;
}
-#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
- DoubleValue) \
- (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue))
+#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \
+ ScalarRawInt32Value, DoubleValue) \
+ (gemmlowp::CheckedFixedPointConstant<FixedPointType>( \
+ gemmlowp::RescaleConstantInitializer<FixedPointType>( \
+ ScalarRawInt32Value), \
+ DoubleValue))
#else
-#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
- DoubleValue) \
- (FixedPointType::FromScalarRaw(ScalarRawValue))
+#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \
+ ScalarRawInt32Value, DoubleValue) \
+ (FixedPointType::FromScalarRaw( \
+ gemmlowp::RescaleConstantInitializer<FixedPointType>( \
+ ScalarRawInt32Value)))
#endif
// Implementation of exponential function.
@@ -620,8 +714,9 @@ FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
SaturatingRoundingMultiplyByPOT<-1>(
((x4_over_4 + x3) * constant_1_over_3) + x2);
- return constant_term +
- constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2);
+ return AddSaturatingIf16Bit(
+ constant_term,
+ constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
}
// Returns exp(x) for x < 0.
@@ -661,7 +756,7 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
#undef GEMMLOWP_EXP_BARREL_SHIFTER
if (kIntegerBits > 5) {
- static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0;
+ static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
const InputF clamp =
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
@@ -774,6 +869,8 @@ FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
#include "./fixedpoint_neon.h"
#elif defined(GEMMLOWP_SSE4)
#include "./fixedpoint_sse.h"
+#elif defined(GEMMLOWP_MSA)
+#include "./fixedpoint_msa.h"
#endif
#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h
new file mode 100644
index 0000000..c7a110c
--- /dev/null
+++ b/fixedpoint/fixedpoint_msa.h
@@ -0,0 +1,354 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// fixedpoint_msa.h: optimized MSA specializations of the templates
+// in fixedpoint.h.
+
+#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
+#define GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
+
+#include <msa.h>
+
+namespace gemmlowp {
+
+template <>
+struct FixedPointRawTypeTraits<v4i32> {
+ typedef std::int32_t ScalarRawType;
+ static const int kLanes = 4;
+};
+
+template <>
+struct FixedPointRawTypeTraits<v8i16> {
+ typedef std::int16_t ScalarRawType;
+ static const int kLanes = 8;
+};
+
+template <>
+inline v4i32 BitAnd(v4i32 a, v4i32 b) {
+ return reinterpret_cast<v4i32>(__builtin_msa_and_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v8i16 BitAnd(v8i16 a, v8i16 b) {
+ return reinterpret_cast<v8i16>(__builtin_msa_and_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v4i32 BitOr(v4i32 a, v4i32 b) {
+ return reinterpret_cast<v4i32>(__builtin_msa_or_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v8i16 BitOr(v8i16 a, v8i16 b) {
+ return reinterpret_cast<v8i16>(__builtin_msa_or_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v4i32 BitXor(v4i32 a, v4i32 b) {
+ return reinterpret_cast<v4i32>(__builtin_msa_xor_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v8i16 BitXor(v8i16 a, v8i16 b) {
+ return reinterpret_cast<v8i16>(__builtin_msa_xor_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(b)));
+}
+
+template <>
+inline v4i32 BitNot(v4i32 a) {
+ return reinterpret_cast<v4i32>(__builtin_msa_nor_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(a)));
+}
+
+template <>
+inline v8i16 BitNot(v8i16 a) {
+ return reinterpret_cast<v8i16>(__builtin_msa_nor_v(reinterpret_cast<v16u8>(a),
+ reinterpret_cast<v16u8>(a)));
+}
+
+template <>
+inline v4i32 Add(v4i32 a, v4i32 b) {
+ return __builtin_msa_addv_w(a, b);
+}
+
+template <>
+inline v8i16 Add(v8i16 a, v8i16 b) {
+ return __builtin_msa_addv_h(a, b);
+}
+
+template <>
+inline v4i32 Sub(v4i32 a, v4i32 b) {
+ return __builtin_msa_subv_w(a, b);
+}
+
+template <>
+inline v8i16 Sub(v8i16 a, v8i16 b) {
+ return __builtin_msa_subv_h(a, b);
+}
+
+template <>
+inline v4i32 Neg(v4i32 a) {
+ v4i32 zeroes = __builtin_msa_ldi_w(0);
+ return __builtin_msa_subv_w(zeroes, a);
+}
+
+template <>
+inline v8i16 Neg(v8i16 a) {
+ v8i16 zeroes = __builtin_msa_ldi_h(0);
+ return __builtin_msa_subv_h(zeroes, a);
+}
+
+template <>
+inline v4i32 ShiftLeft(v4i32 a, int offset) {
+ return __builtin_msa_sll_w(a, __builtin_msa_fill_w(offset));
+}
+
+template <>
+inline v8i16 ShiftLeft(v8i16 a, int offset) {
+ return __builtin_msa_sll_h(a, __builtin_msa_fill_h(offset));
+}
+
+template <>
+inline v4i32 ShiftRight(v4i32 a, int offset) {
+ return __builtin_msa_sra_w(a, __builtin_msa_fill_w(offset));
+}
+
+template <>
+inline v8i16 ShiftRight(v8i16 a, int offset) {
+ return __builtin_msa_sra_h(a, __builtin_msa_fill_h(offset));
+}
+
+template <>
+inline v4i32 SelectUsingMask(v4i32 if_mask, v4i32 then_val, v4i32 else_val) {
+ if_mask = reinterpret_cast<v4i32>(__builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
+ reinterpret_cast<v16u8>(else_val),
+ reinterpret_cast<v16u8>(then_val)));
+ return if_mask;
+}
+
+template <>
+inline v8i16 SelectUsingMask(v8i16 if_mask, v8i16 then_val, v8i16 else_val) {
+ if_mask = reinterpret_cast<v8i16>(__builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
+ reinterpret_cast<v16u8>(else_val),
+ reinterpret_cast<v16u8>(then_val)));
+ return if_mask;
+}
+
+template <>
+inline v4i32 MaskIfEqual(v4i32 a, v4i32 b) {
+ return __builtin_msa_ceq_w(a, b);
+}
+
+template <>
+inline v8i16 MaskIfEqual(v8i16 a, v8i16 b) {
+ return __builtin_msa_ceq_h(a, b);
+}
+
+template <>
+inline v4i32 MaskIfNotEqual(v4i32 a, v4i32 b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
+inline v8i16 MaskIfNotEqual(v8i16 a, v8i16 b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
+inline v4i32 MaskIfZero(v4i32 a) {
+ return __builtin_msa_ceqi_w(a, 0);
+}
+
+template <>
+inline v8i16 MaskIfZero(v8i16 a) {
+ return __builtin_msa_ceqi_h(a, 0);
+}
+
+template <>
+inline v4i32 MaskIfNonZero(v4i32 a) {
+ return BitNot(MaskIfZero(a));
+}
+
+template <>
+inline v8i16 MaskIfNonZero(v8i16 a) {
+ return BitNot(MaskIfZero(a));
+}
+
+template <>
+inline v4i32 MaskIfGreaterThan(v4i32 a, v4i32 b) {
+ return __builtin_msa_clt_s_w(b, a);
+}
+
+template <>
+inline v8i16 MaskIfGreaterThan(v8i16 a, v8i16 b) {
+ return __builtin_msa_clt_s_h(b, a);
+}
+
+template <>
+inline v4i32 MaskIfGreaterThanOrEqual(v4i32 a, v4i32 b) {
+ return __builtin_msa_cle_s_w(b, a);
+}
+
+template <>
+inline v8i16 MaskIfGreaterThanOrEqual(v8i16 a, v8i16 b) {
+ return __builtin_msa_cle_s_h(b, a);
+}
+
+template <>
+inline v4i32 MaskIfLessThan(v4i32 a, v4i32 b) {
+ return __builtin_msa_clt_s_w(a, b);
+}
+
+template <>
+inline v8i16 MaskIfLessThan(v8i16 a, v8i16 b) {
+ return __builtin_msa_clt_s_h(a, b);
+}
+
+template <>
+inline v4i32 MaskIfLessThanOrEqual(v4i32 a, v4i32 b) {
+ return __builtin_msa_cle_s_w(a, b);
+}
+
+template <>
+inline v8i16 MaskIfLessThanOrEqual(v8i16 a, v8i16 b) {
+ return __builtin_msa_cle_s_h(a, b);
+}
+
+template <>
+inline bool All(v4i32 a) {
+ return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
+}
+
+template <>
+inline bool All(v8i16 a) {
+ return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
+}
+
+template <>
+inline bool Any(v4i32 a) {
+ return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
+}
+
+template <>
+inline bool Any(v8i16 a) {
+ return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
+}
+
+template <>
+inline v4i32 RoundingHalfSum(v4i32 a, v4i32 b) {
+ return __builtin_msa_aver_s_w(a, b);
+}
+
+template <>
+inline v8i16 RoundingHalfSum(v8i16 a, v8i16 b) {
+ return __builtin_msa_aver_s_h(a, b);
+}
+
+template <>
+inline v4i32 SaturatingRoundingDoublingHighMul(v4i32 a, v4i32 b) {
+ return __builtin_msa_mulr_q_w(a, b);
+}
+
+template <>
+inline v8i16 SaturatingRoundingDoublingHighMul(v8i16 a, v8i16 b) {
+ return __builtin_msa_mulr_q_h(a, b);
+}
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, 1> {
+ static v4i32 eval(v4i32 x) {
+ static_assert(Exponent >= 0 && Exponent < 32, "");
+ if (Exponent < 5) {
+ for (int i = 0; i < Exponent; i++) {
+ x = __builtin_msa_adds_s_w(x, x);
+ }
+ return x;
+ } else {
+ // Saturate each signed 32-bit element to (32 - Exponent)
+ // bits (this takes full care of negative elements).
+ v4i32 res = __builtin_msa_sat_s_w(x, 31 - Exponent);
+ // Set tmp to 0x7FFFFFFF for those elements which staturated
+ // to smaller (positive) values and 0 for all others.
+ v4i32 tmp = __builtin_msa_srli_w(__builtin_msa_clt_s_w(res, x), 1);
+ // Shift the saturated elements. The positive saturated elements
+ // will have Exponent trailing zero bits after the shift. Those
+ // need to be ones, not zeroes.
+ res = __builtin_msa_slli_w(res, Exponent);
+ // Finally, set those trailing zero bits to ones.
+ res = reinterpret_cast<v4i32>(__builtin_msa_or_v(reinterpret_cast<v16u8>(res),
+ reinterpret_cast<v16u8>(tmp)));
+ return res;
+ }
+ }
+};
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> {
+ static v8i16 eval(v8i16 x) {
+ static_assert(Exponent >= 0 && Exponent < 16, "");
+ if (Exponent < 5) {
+ for (int i = 0; i < Exponent; i++) {
+ x = __builtin_msa_adds_s_h(x, x);
+ }
+ return x;
+ } else {
+ // Saturate each signed 16-bit element to (16 - Exponent)
+ // bits (this takes full care of negative elements).
+ v8i16 res = __builtin_msa_sat_s_h(x, 15 - Exponent);
+ // Set tmp to 0x7FFF for those elements which staturated
+ // to smaller (positive) values and 0 for all others.
+ v8i16 tmp = __builtin_msa_srli_h(__builtin_msa_clt_s_h(res, x), 1);
+ // Shift the saturated elements. The positive saturated elements
+ // will have Exponent trailing zero bits after the shift. Those
+ // need to be ones, not zeroes.
+ res = __builtin_msa_slli_h(res, Exponent);
+ // Finally, set those trailing zero bits to ones.
+ res = reinterpret_cast<v8i16>(__builtin_msa_or_v(reinterpret_cast<v16u8>(res),
+ reinterpret_cast<v16u8>(tmp)));
+ return res;
+ }
+ }
+};
+
+// TODO: possibly implement:
+// template <> v4i32 RoundingDivideByPOT(v4i32, int)
+// template <> v8i16 RoundingDivideByPOT(v8i16, int)
+// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1>
+// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1>
+
+template <>
+inline v4i32 Dup<v4i32>(std::int32_t x) {
+ return __builtin_msa_fill_w(x);
+}
+
+template <>
+inline v8i16 Dup<v8i16>(std::int16_t x) {
+ return __builtin_msa_fill_h(x);
+}
+
+// So far this is only needed for int16.
+template <>
+inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
+ return __builtin_msa_adds_s_h(a, b);
+ return a;
+}
+
+} // end namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
diff --git a/fixedpoint/fixedpoint_neon.h b/fixedpoint/fixedpoint_neon.h
index 8b23de2..92b349b 100644
--- a/fixedpoint/fixedpoint_neon.h
+++ b/fixedpoint/fixedpoint_neon.h
@@ -29,97 +29,194 @@ struct FixedPointRawTypeTraits<int32x4_t> {
};
template <>
+struct FixedPointRawTypeTraits<int16x8_t> {
+ typedef std::int16_t ScalarRawType;
+ static const int kLanes = 8;
+};
+
+template <>
inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
return vandq_s32(a, b);
}
template <>
+inline int16x8_t BitAnd(int16x8_t a, int16x8_t b) {
+ return vandq_s16(a, b);
+}
+
+template <>
inline int32x4_t BitOr(int32x4_t a, int32x4_t b) {
return vorrq_s32(a, b);
}
template <>
+inline int16x8_t BitOr(int16x8_t a, int16x8_t b) {
+ return vorrq_s16(a, b);
+}
+
+template <>
inline int32x4_t BitXor(int32x4_t a, int32x4_t b) {
return veorq_s32(a, b);
}
template <>
+inline int16x8_t BitXor(int16x8_t a, int16x8_t b) {
+ return veorq_s16(a, b);
+}
+
+template <>
inline int32x4_t BitNot(int32x4_t a) {
return veorq_s32(a, vdupq_n_s32(-1));
}
template <>
+inline int16x8_t BitNot(int16x8_t a) {
+ return veorq_s16(a, vdupq_n_s16(-1));
+}
+
+template <>
inline int32x4_t Add(int32x4_t a, int32x4_t b) {
return vaddq_s32(a, b);
}
template <>
+inline int16x8_t Add(int16x8_t a, int16x8_t b) {
+ return vaddq_s16(a, b);
+}
+
+template <>
inline int32x4_t Sub(int32x4_t a, int32x4_t b) {
return vsubq_s32(a, b);
}
template <>
+inline int16x8_t Sub(int16x8_t a, int16x8_t b) {
+ return vsubq_s16(a, b);
+}
+
+template <>
inline int32x4_t Neg(int32x4_t a) {
return vnegq_s32(a);
}
template <>
+inline int16x8_t Neg(int16x8_t a) {
+ return vnegq_s16(a);
+}
+
+template <>
inline int32x4_t ShiftLeft(int32x4_t a, int offset) {
return vshlq_s32(a, vdupq_n_s32(offset));
}
template <>
+inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
+ return vshlq_s16(a, vdupq_n_s16(offset));
+}
+
+template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
return vshlq_s32(a, vdupq_n_s32(-offset));
}
template <>
+inline int16x8_t ShiftRight(int16x8_t a, int offset) {
+ return vshlq_s16(a, vdupq_n_s16(-offset));
+}
+
+template <>
inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val,
int32x4_t else_val) {
return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val);
}
template <>
+inline int16x8_t SelectUsingMask(int16x8_t if_mask, int16x8_t then_val,
+ int16x8_t else_val) {
+ return vbslq_s16(vreinterpretq_u16_s16(if_mask), then_val, else_val);
+}
+
+template <>
inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vceqq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfEqual(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vceqq_s16(a, b));
+}
+
+template <>
inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) {
return BitNot(MaskIfEqual(a, b));
}
template <>
+inline int16x8_t MaskIfNotEqual(int16x8_t a, int16x8_t b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
inline int32x4_t MaskIfZero(int32x4_t a) {
return MaskIfEqual(a, vdupq_n_s32(0));
}
template <>
+inline int16x8_t MaskIfZero(int16x8_t a) {
+ return MaskIfEqual(a, vdupq_n_s16(0));
+}
+
+template <>
inline int32x4_t MaskIfNonZero(int32x4_t a) {
return vreinterpretq_s32_u32(vtstq_s32(a, a));
}
template <>
+inline int16x8_t MaskIfNonZero(int16x8_t a) {
+ return vreinterpretq_s16_u16(vtstq_s16(a, a));
+}
+
+template <>
inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vcgtq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfGreaterThan(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vcgtq_s16(a, b));
+}
+
+template <>
inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vcgeq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfGreaterThanOrEqual(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vcgeq_s16(a, b));
+}
+
+template <>
inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vcltq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfLessThan(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vcltq_s16(a, b));
+}
+
+template <>
inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) {
return vreinterpretq_s32_u32(vcleq_s32(a, b));
}
template <>
+inline int16x8_t MaskIfLessThanOrEqual(int16x8_t a, int16x8_t b) {
+ return vreinterpretq_s16_u16(vcleq_s16(a, b));
+}
+
+template <>
inline bool All(int32x4_t a) {
a = vandq_s32(a, vextq_s32(a, a, 1));
a = vandq_s32(a, vextq_s32(a, a, 2));
@@ -127,6 +224,14 @@ inline bool All(int32x4_t a) {
}
template <>
+inline bool All(int16x8_t a) {
+ a = vandq_s16(a, vextq_s16(a, a, 1));
+ a = vandq_s16(a, vextq_s16(a, a, 2));
+ a = vandq_s16(a, vextq_s16(a, a, 4));
+ return vgetq_lane_s16(a, 0);
+}
+
+template <>
inline bool Any(int32x4_t a) {
a = vorrq_s32(a, vextq_s32(a, a, 1));
a = vorrq_s32(a, vextq_s32(a, a, 2));
@@ -134,16 +239,34 @@ inline bool Any(int32x4_t a) {
}
template <>
+inline bool Any(int16x8_t a) {
+ a = vorrq_s16(a, vextq_s16(a, a, 1));
+ a = vorrq_s16(a, vextq_s16(a, a, 2));
+ a = vorrq_s16(a, vextq_s16(a, a, 4));
+ return vgetq_lane_s16(a, 0);
+}
+
+template <>
inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) {
return vrhaddq_s32(a, b);
}
template <>
+inline int16x8_t RoundingHalfSum(int16x8_t a, int16x8_t b) {
+ return vrhaddq_s16(a, b);
+}
+
+template <>
inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) {
return vqrdmulhq_s32(a, b);
}
template <>
+inline int16x8_t SaturatingRoundingDoublingHighMul(int16x8_t a, int16x8_t b) {
+ return vqrdmulhq_s16(a, b);
+}
+
+template <>
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) {
const int32x4_t shift_vec = vdupq_n_s32(-exponent);
const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
@@ -151,6 +274,14 @@ inline int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) {
return vrshlq_s32(fixed_up_x, shift_vec);
}
+template <>
+inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
+ const int16x8_t shift_vec = vdupq_n_s16(-exponent);
+ const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
+ const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
+ return vrshlq_s16(fixed_up_x, shift_vec);
+}
+
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
@@ -165,11 +296,36 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, -1> {
}
};
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, 1> {
+ static int16x8_t eval(int16x8_t x) { return vqshlq_n_s16(x, Exponent); }
+};
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, -1> {
+ static int16x8_t eval(int16x8_t x) {
+ const int16x8_t fixup = vshrq_n_s16(x, 15);
+ const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
+ return vrshrq_n_s16(fixed_up_x, -Exponent);
+ }
+};
+
template <>
inline int32x4_t Dup<int32x4_t>(std::int32_t x) {
return vdupq_n_s32(x);
}
+template <>
+inline int16x8_t Dup<int16x8_t>(std::int16_t x) {
+ return vdupq_n_s16(x);
+}
+
+// So far this is only needed for int16.
+template <>
+inline int16x8_t SaturatingAdd(int16x8_t a, int16x8_t b) {
+ return vqaddq_s16(a, b);
+}
+
} // end namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h
index 3f2654d..ba990f0 100644
--- a/fixedpoint/fixedpoint_sse.h
+++ b/fixedpoint/fixedpoint_sse.h
@@ -23,6 +23,22 @@
namespace gemmlowp {
+// SSE intrinsics are not finely typed: there is a single __m128i vector
+// type that does not distinguish between "int32x4" and "int16x8" use
+// cases, unlike the NEON equivalents. Because we had initially focused
+// on int32x4, we did not pay attention and specialized these fixedpoint
+// templates directly for __m128i hardcoding the int32x4 semantics,
+// not leaving room for int16x8 semantics. Amending that by adding a separate
+// data type, int16x8_m128i, that wraps __m128i while being a separate
+// type.
+struct int16x8_m128i {
+ int16x8_m128i() {}
+ explicit int16x8_m128i(__m128i w) : v(w) {}
+ ~int16x8_m128i() {}
+
+ __m128i v;
+};
+
template <>
struct FixedPointRawTypeTraits<__m128i> {
typedef std::int32_t ScalarRawType;
@@ -30,61 +46,125 @@ struct FixedPointRawTypeTraits<__m128i> {
};
template <>
+struct FixedPointRawTypeTraits<int16x8_m128i> {
+ typedef std::int16_t ScalarRawType;
+ static const int kLanes = 8;
+};
+
+template <>
inline __m128i BitAnd(__m128i a, __m128i b) {
return _mm_and_si128(a, b);
}
template <>
+inline int16x8_m128i BitAnd(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_and_si128(a.v, b.v));
+}
+
+template <>
inline __m128i BitOr(__m128i a, __m128i b) {
return _mm_or_si128(a, b);
}
template <>
+inline int16x8_m128i BitOr(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_or_si128(a.v, b.v));
+}
+
+template <>
inline __m128i BitXor(__m128i a, __m128i b) {
return _mm_xor_si128(a, b);
}
template <>
+inline int16x8_m128i BitXor(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_xor_si128(a.v, b.v));
+}
+
+template <>
inline __m128i BitNot(__m128i a) {
return _mm_andnot_si128(a, _mm_set1_epi32(-1));
}
template <>
+inline int16x8_m128i BitNot(int16x8_m128i a) {
+ return int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
+}
+
+template <>
inline __m128i Add(__m128i a, __m128i b) {
return _mm_add_epi32(a, b);
}
template <>
+inline int16x8_m128i Add(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_add_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i Mul(__m128i a, __m128i b) {
return _mm_mullo_epi32(a, b);
}
template <>
+inline int16x8_m128i Mul(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i Sub(__m128i a, __m128i b) {
return _mm_sub_epi32(a, b);
}
template <>
+inline int16x8_m128i Sub(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_sub_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i Neg(__m128i a) {
return _mm_sign_epi32(a, _mm_set1_epi32(-1));
}
template <>
+inline int16x8_m128i Neg(int16x8_m128i a) {
+ return int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
+}
+
+template <>
inline __m128i ShiftLeft(__m128i a, int offset) {
return _mm_slli_epi32(a, offset);
}
template <>
+inline int16x8_m128i ShiftLeft(int16x8_m128i a, int offset) {
+ return int16x8_m128i(_mm_slli_epi16(a.v, offset));
+}
+
+template <>
inline __m128i ShiftRight(__m128i a, int offset) {
return _mm_srai_epi32(a, offset);
}
template <>
+inline int16x8_m128i ShiftRight(int16x8_m128i a, int offset) {
+ return int16x8_m128i(_mm_srai_epi16(a.v, offset));
+}
+
+template <>
inline __m128i SelectUsingMask(__m128i if_mask, __m128i then_val,
__m128i else_val) {
- return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(else_val),
- _mm_castsi128_ps(then_val),
- _mm_castsi128_ps(if_mask)));
+ // borrowed from Intel's arm_neon_sse.h header.
+ return _mm_or_si128(_mm_and_si128(if_mask, then_val),
+ _mm_andnot_si128(if_mask, else_val));
+}
+
+template <>
+inline int16x8_m128i SelectUsingMask(int16x8_m128i if_mask,
+ int16x8_m128i then_val,
+ int16x8_m128i else_val) {
+ // borrowed from Intel's arm_neon_sse.h header.
+ return int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
}
template <>
@@ -93,40 +173,81 @@ inline __m128i MaskIfEqual(__m128i a, __m128i b) {
}
template <>
+inline int16x8_m128i MaskIfEqual(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i MaskIfNotEqual(__m128i a, __m128i b) {
return BitNot(MaskIfEqual(a, b));
}
template <>
+inline int16x8_m128i MaskIfNotEqual(int16x8_m128i a, int16x8_m128i b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
inline __m128i MaskIfZero(__m128i a) {
return MaskIfEqual(a, _mm_set1_epi32(0));
}
template <>
+inline int16x8_m128i MaskIfZero(int16x8_m128i a) {
+ return MaskIfEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
+}
+
+template <>
inline __m128i MaskIfNonZero(__m128i a) {
return MaskIfNotEqual(a, _mm_set1_epi32(0));
}
template <>
+inline int16x8_m128i MaskIfNonZero(int16x8_m128i a) {
+ return MaskIfNotEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
+}
+
+template <>
inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) {
return _mm_cmpgt_epi32(a, b);
}
template <>
+inline int16x8_m128i MaskIfGreaterThan(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i MaskIfLessThan(__m128i a, __m128i b) {
return _mm_cmplt_epi32(a, b);
}
template <>
+inline int16x8_m128i MaskIfLessThan(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
+}
+
+template <>
inline __m128i MaskIfGreaterThanOrEqual(__m128i a, __m128i b) {
return BitNot(MaskIfLessThan(a, b));
}
template <>
+inline int16x8_m128i MaskIfGreaterThanOrEqual(int16x8_m128i a,
+ int16x8_m128i b) {
+ return BitNot(MaskIfLessThan(a, b));
+}
+
+template <>
inline __m128i MaskIfLessThanOrEqual(__m128i a, __m128i b) {
return BitNot(MaskIfGreaterThan(a, b));
}
+template <>
+inline int16x8_m128i MaskIfLessThanOrEqual(int16x8_m128i a, int16x8_m128i b) {
+ return BitNot(MaskIfGreaterThan(a, b));
+}
+
/* Assumptions:
- All and Any are used on masks.
- masks are all_ones for true lanes, all_zeroes otherwise.
@@ -139,8 +260,18 @@ inline bool All(__m128i a) {
}
template <>
+inline bool All(int16x8_m128i a) {
+ return _mm_testc_si128(a.v, a.v);
+}
+
+template <>
inline bool Any(__m128i a) {
- return BitNot(_mm_testz_si128(a, a));
+ return !_mm_testz_si128(a, a);
+}
+
+template <>
+inline bool Any(int16x8_m128i a) {
+ return !_mm_testz_si128(a.v, a.v);
}
template <>
@@ -171,6 +302,18 @@ inline __m128i RoundingHalfSum(__m128i a, __m128i b) {
}
template <>
+inline int16x8_m128i RoundingHalfSum(int16x8_m128i a, int16x8_m128i b) {
+ // Idea: go to unsigned to use _mm_avg_epu16,
+ // borrowed from Intel's arm_neon_sse.h header.
+ __m128i constant_neg_32768 = _mm_set1_epi16(-32768);
+ __m128i a_unsigned = _mm_sub_epi16(a.v, constant_neg_32768);
+ __m128i b_unsigned = _mm_sub_epi16(b.v, constant_neg_32768);
+ __m128i avg_unsigned = _mm_avg_epu16(a_unsigned, b_unsigned);
+ __m128i avg = _mm_add_epi16(avg_unsigned, constant_neg_32768);
+ return int16x8_m128i(avg);
+}
+
+template <>
inline __m128i SaturatingRoundingDoublingHighMul(__m128i a, __m128i b) {
__m128i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
__m128i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
@@ -209,10 +352,33 @@ inline __m128i SaturatingRoundingDoublingHighMul(__m128i a, __m128i b) {
}
template <>
+inline int16x8_m128i SaturatingRoundingDoublingHighMul(int16x8_m128i a,
+ int16x8_m128i b) {
+ // Idea: use _mm_mulhrs_epi16 then saturate with a bit-operation,
+ // borrowed from Intel's arm_neon_sse.h header.
+ __m128i result_unsaturated = _mm_mulhrs_epi16(a.v, b.v);
+ __m128i saturation_mask =
+ _mm_cmpeq_epi16(result_unsaturated, _mm_set1_epi16(0x8000));
+ __m128i result = _mm_xor_si128(result_unsaturated, saturation_mask);
+ return int16x8_m128i(result);
+}
+
+template <>
inline __m128i Dup<__m128i>(std::int32_t x) {
return _mm_set1_epi32(x);
}
+template <>
+inline int16x8_m128i Dup<int16x8_m128i>(std::int16_t x) {
+ return int16x8_m128i(_mm_set1_epi16(x));
+}
+
+// So far this is only needed for int16.
+template <>
+inline int16x8_m128i SaturatingAdd(int16x8_m128i a, int16x8_m128i b) {
+ return int16x8_m128i(_mm_adds_epi16(a.v, b.v));
+}
+
} // end namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_
diff --git a/internal/common.h b/internal/common.h
index 9de151b..26b6713 100644
--- a/internal/common.h
+++ b/internal/common.h
@@ -55,6 +55,19 @@
#define GEMMLOWP_ARM
#endif
+// Detect MIPS, 32-bit or 64-bit
+#if defined(__mips) && !defined(__LP64__)
+#define GEMMLOWP_MIPS_32
+#endif
+
+#if defined(__mips) && defined(__LP64__)
+#define GEMMLOWP_MIPS_64
+#endif
+
+#if defined(GEMMLOWP_MIPS_32) || defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MIPS
+#endif
+
// Detect x86, 32-bit or 64-bit
#if defined(__i386__) || defined(_M_IX86) || defined(_X86_) || defined(__i386)
#define GEMMLOWP_X86_32
@@ -87,6 +100,23 @@
#define GEMMLOWP_NEON_64
#endif
+// Detect MIPS MSA.
+// Limit MSA optimizations to little-endian CPUs for now.
+// TODO: Perhaps, eventually support MSA optimizations on big-endian CPUs?
+#if defined(GEMMLOWP_MIPS) && (__mips_isa_rev >= 5) && defined(__mips_msa) && \
+ defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+#define GEMMLOWP_MSA
+#endif
+
+// Convenience MIPS MSA tokens for 32-bit or 64-bit.
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_32)
+#define GEMMLOWP_MSA_32
+#endif
+
+#if defined(GEMMLOWP_MSA) && defined(GEMMLOWP_MIPS_64)
+#define GEMMLOWP_MSA_64
+#endif
+
// Detect SSE.
#ifdef __SSE4_1__
#define GEMMLOWP_SSE4
@@ -97,7 +127,8 @@
#endif
// Convenience SSE4 tokens for 32-bit or 64-bit
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32)
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_32) && \
+ !defined(GEMMLOWP_DISABLE_SSE4)
#define GEMMLOWP_SSE4_32
#endif
@@ -105,7 +136,8 @@
#define GEMMLOWP_SSE3_32
#endif
-#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64)
+#if defined(GEMMLOWP_SSE4) && defined(GEMMLOWP_X86_64) && \
+ !defined(GEMMLOWP_DISABLE_SSE4)
#define GEMMLOWP_SSE4_64
#endif
@@ -178,6 +210,10 @@ const int kDefaultL2CacheSize = 4 * 1024 * 1024;
// x86-32 and not Android. Same as x86-64 but less bullish.
const int kDefaultL1CacheSize = 32 * 1024;
const int kDefaultL2CacheSize = 2 * 1024 * 1024;
+#elif defined(GEMMLOWP_MIPS)
+// MIPS and not Android. TODO: MIPS and Android?
+const int kDefaultL1CacheSize = 32 * 1024;
+const int kDefaultL2CacheSize = 1024 * 1024;
#else
// Less common hardware. Maybe some unusual or older or embedded thing.
// Assume smaller caches, but don't depart too far from what we do
diff --git a/internal/kernel_default.h b/internal/kernel_default.h
index 7037bda..a919ffe 100644
--- a/internal/kernel_default.h
+++ b/internal/kernel_default.h
@@ -18,18 +18,13 @@
#ifndef GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_
#define GEMMLOWP_INTERNAL_KERNEL_DEFAULT_H_
-#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#endif
-
#include "../public/bit_depth.h"
#include "common.h"
#include "kernel_reference.h"
namespace gemmlowp {
-template <bool MaxProductIsLessThan4096,
- bool LhsAlwaysNonzero>
+template <bool MaxProductIsLessThan4096, bool LhsAlwaysNonzero>
struct DefaultKernelImpl {};
// Partial specialization implementing the logic that if we want to use
@@ -56,12 +51,12 @@ struct DefaultKernel
} // end namespace gemmlowp
-#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \
- LhsAlwaysNonzero, Kernel) \
- namespace gemmlowp { \
- template <> \
- struct DefaultKernelImpl<MaxProductIsLessThan4096, \
- LhsAlwaysNonzero> : Kernel {}; \
+#define GEMMLOWP_SET_DEFAULT_KERNEL(MaxProductIsLessThan4096, \
+ LhsAlwaysNonzero, Kernel) \
+ namespace gemmlowp { \
+ template <> \
+ struct DefaultKernelImpl<MaxProductIsLessThan4096, LhsAlwaysNonzero> \
+ : Kernel {}; \
}
#if defined GEMMLOWP_NEON_32
@@ -76,6 +71,9 @@ GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, NEON_64_Kernel12x8Depth2)
GEMMLOWP_SET_DEFAULT_KERNEL(false, true,
NEON_64bit_GEMM_Int8Operands_LhsNonzero)
+#elif defined(GEMMLOWP_MSA)
+#include "kernel_msa.h"
+GEMMLOWP_SET_DEFAULT_KERNEL(false, false, MSA_Kernel12x8Depth2)
#elif defined GEMMLOWP_SSE4_32
#include "kernel_sse.h"
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2)
@@ -83,23 +81,6 @@ GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_32_Kernel4x4Depth2)
#include "kernel_sse.h"
GEMMLOWP_SET_DEFAULT_KERNEL(false, false, SSE4_64_Kernel12x4Depth2)
#else
-#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
-#if defined __ARM_ARCH_5TE__
-// SIMD is not available on this platform. The slow fallback will be used.
-// Don't require GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK because there's nothing
-// the user can do about it.
-#elif defined __powerpc__
-// There is currently no fast kernel using SIMD instructions on POWER. Don't
-// require GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK because there's nothing the user
-// can do about it.
-#else
-#error \
- "SIMD not enabled, you'd be getting a slow software fallback. Consider \
-enabling SIMD extensions (for example using -msse4 if you're on modern x86). \
-If that's not an option, and you would like to continue with the \
-slow fallback, define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK."
-#endif
-#endif
#include "kernel_reference.h"
namespace gemmlowp {
typedef ReferenceKernel<KernelFormat<
diff --git a/internal/kernel_msa.h b/internal/kernel_msa.h
new file mode 100644
index 0000000..4985b73
--- /dev/null
+++ b/internal/kernel_msa.h
@@ -0,0 +1,339 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// kernel_msa.h: a collection of MSA optimized kernels.
+// Check in kernel_default.h which one(s) are actually used by default.
+// Others are mere experiments; they are still covered by tests
+// in case they might be useful some day.
+
+#ifndef GEMMLOWP_INTERNAL_KERNEL_MSA_H_
+#define GEMMLOWP_INTERNAL_KERNEL_MSA_H_
+
+#include "kernel.h"
+
+#include <msa.h>
+#include <cassert>
+
+namespace gemmlowp {
+
+#ifdef GEMMLOWP_MSA
+
+// Some convenience macros to hide differences between MIPS32 and MIPS64.
+#ifdef GEMMLOWP_MIPS_64
+#define GEMMLOWP_MIPS_XADDU "daddu"
+#define GEMMLOWP_MIPS_XADDIU "daddiu"
+#define GEMMLOWP_MIPS_XSLL "dsll"
+#else
+#define GEMMLOWP_MIPS_XADDU "addu"
+#define GEMMLOWP_MIPS_XADDIU "addiu"
+#define GEMMLOWP_MIPS_XSLL "sll"
+#endif
+
+// Our main GEMM kernel.
+struct MSA_Kernel12x8Depth2 : KernelBase {
+ typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
+ KernelSideFormat<CellFormat<4, 2>, 2> >
+ Format;
+
+ const char* Name() const override { return "MSA, 12x8, depth 2"; }
+
+ // TODO(benoitjacob): reorder function arguments so dst comes last
+ void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
+ std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
+ const std::uint8_t* rhs_ptr, std::size_t start_depth,
+ std::size_t run_depth) const override {
+ ScopedProfilingLabel label("optimized kernel (MSA 12x8)");
+// See comments above for why we need local numerical labels in our asm.
+#define GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "1"
+#define GEMMLOWP_LABEL_BEFORE_LOOP "2"
+#define GEMMLOWP_LABEL_LOOP "3"
+#define GEMMLOWP_LABEL_AFTER_LOOP "4"
+
+ assert(dst_row_stride == 1);
+ asm volatile(
+ // Set a temp to all zeroes.
+ "ldi.b $w31, 0\n"
+
+ // Multiply dst_col_stride by 4 == sizeof(int32) to use
+ // it as a byte offset below.
+ GEMMLOWP_MIPS_XSLL
+ " %[dst_col_stride], %[dst_col_stride], 2\n"
+
+ // Check if start_depth==0 to decide whether we will clear
+ // accumulators or load existing accumulators.
+ "beqz %[start_depth], " GEMMLOWP_LABEL_CLEAR_ACCUMULATORS "f\n"
+
+ // Load accumulators (start_depth != 0).
+ GEMMLOWP_MIPS_XADDU
+ " $a0, %[dst_ptr], %[dst_col_stride]\n"
+ "ld.w $w0, (0*16)(%[dst_ptr])\n"
+ "ld.w $w4, (1*16)(%[dst_ptr])\n"
+ "ld.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w1, (0*16)($a0)\n"
+ "ld.w $w5, (1*16)($a0)\n"
+ "ld.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w2, (0*16)($a1)\n"
+ "ld.w $w6, (1*16)($a1)\n"
+ "ld.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w3, (0*16)($a0)\n"
+ "ld.w $w7, (1*16)($a0)\n"
+ "ld.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w12, (0*16)($a1)\n"
+ "ld.w $w16, (1*16)($a1)\n"
+ "ld.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "ld.w $w13, (0*16)($a0)\n"
+ "ld.w $w17, (1*16)($a0)\n"
+ "ld.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "ld.w $w14, (0*16)($a1)\n"
+ "ld.w $w18, (1*16)($a1)\n"
+ "ld.w $w22, (2*16)($a1)\n"
+ "ld.w $w15, (0*16)($a0)\n"
+ "ld.w $w19, (1*16)($a0)\n"
+ "ld.w $w23, (2*16)($a0)\n"
+ "b " GEMMLOWP_LABEL_BEFORE_LOOP "f\n"
+
+ GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+ ":\n"
+ // Clear accumulators (start_depth == 0).
+ "ldi.w $w0, 0\n"
+ "ldi.w $w4, 0\n"
+ "ldi.w $w8, 0\n"
+ "ldi.w $w1, 0\n"
+ "ldi.w $w5, 0\n"
+ "ldi.w $w9, 0\n"
+ "ldi.w $w2, 0\n"
+ "ldi.w $w6, 0\n"
+ "ldi.w $w10, 0\n"
+ "ldi.w $w3, 0\n"
+ "ldi.w $w7, 0\n"
+ "ldi.w $w11, 0\n"
+ "ldi.w $w12, 0\n"
+ "ldi.w $w16, 0\n"
+ "ldi.w $w20, 0\n"
+ "ldi.w $w13, 0\n"
+ "ldi.w $w17, 0\n"
+ "ldi.w $w21, 0\n"
+ "ldi.w $w14, 0\n"
+ "ldi.w $w18, 0\n"
+ "ldi.w $w22, 0\n"
+ "ldi.w $w15, 0\n"
+ "ldi.w $w19, 0\n"
+ "ldi.w $w23, 0\n"
+
+ GEMMLOWP_LABEL_BEFORE_LOOP ":\n"
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+ // Overview of register layout:
+ //
+ // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
+ // (each register contains 4 replicas of a pair of elements).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
+ // A 12x8 block of accumulators is stored in 32bit in w0-w23.
+ //
+ // +------+------+------+------+
+ // Rhs |w27 |w28 |w29 |w30 |
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +------+------+------+------+
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // +---+ - - - - +------+------+------+------+
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // +---+ - - - - +------+------+------+------+
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // +---+ - - - - +------+------+------+------+
+ //
+ // Accumulators
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w24, 0(%[lhs_ptr])\n"
+ "ld.b $w25, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for the first half of depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the first half of depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w24, $w31, $w24\n"
+ "ilvl.b $w26, $w31, $w25\n"
+ "ilvr.b $w25, $w31, $w25\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
+ "ilvl.d $w27, $w31, $w24\n"
+ "ilvl.d $w28, $w31, $w25\n"
+ "ilvl.d $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w27, $w24\n"
+ "ilvr.h $w25, $w28, $w25\n"
+ "ilvr.h $w26, $w29, $w26\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for
+ // dpadd_u.w (for the first half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Load 4 bytes of rhs[] for the second half of depth 0.
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the second half of depth 1.
+ "lbu $v0, 12(%[rhs_ptr])\n"
+ "lbu $v1, 13(%[rhs_ptr])\n"
+ "lbu $t8, 14(%[rhs_ptr])\n"
+ "lbu $t9, 15(%[rhs_ptr])\n"
+
+ // First half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w24, $w27\n"
+ "dpadd_u.w $w4, $w25, $w27\n"
+ "dpadd_u.w $w8, $w26, $w27\n"
+ "dpadd_u.w $w1, $w24, $w28\n"
+ "dpadd_u.w $w5, $w25, $w28\n"
+ "dpadd_u.w $w9, $w26, $w28\n"
+ "dpadd_u.w $w2, $w24, $w29\n"
+ "dpadd_u.w $w6, $w25, $w29\n"
+ "dpadd_u.w $w10, $w26, $w29\n"
+ "dpadd_u.w $w3, $w24, $w30\n"
+ "dpadd_u.w $w7, $w25, $w30\n"
+ "dpadd_u.w $w11, $w26, $w30\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for
+ // dpadd_u.w (for the second half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Second half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w12, $w24, $w27\n"
+ "dpadd_u.w $w16, $w25, $w27\n"
+ "dpadd_u.w $w20, $w26, $w27\n"
+ "dpadd_u.w $w13, $w24, $w28\n"
+ "dpadd_u.w $w17, $w25, $w28\n"
+ "dpadd_u.w $w21, $w26, $w28\n"
+ "dpadd_u.w $w14, $w24, $w29\n"
+ "dpadd_u.w $w18, $w25, $w29\n"
+ "dpadd_u.w $w22, $w26, $w29\n"
+ "dpadd_u.w $w15, $w24, $w30\n"
+ "dpadd_u.w $w19, $w25, $w30\n"
+ "dpadd_u.w $w23, $w26, $w30\n"
+
+ GEMMLOWP_MIPS_XADDIU " %[run_depth], -2\n" GEMMLOWP_MIPS_XADDIU
+ " %[lhs_ptr], 24\n" GEMMLOWP_MIPS_XADDIU
+ " %[rhs_ptr], 16\n"
+ "bnez %[run_depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ GEMMLOWP_LABEL_AFTER_LOOP ":\n"
+
+ // Store accumulators.
+ GEMMLOWP_MIPS_XADDU
+ " $a0, %[dst_ptr], %[dst_col_stride]\n"
+ "st.w $w0, (0*16)(%[dst_ptr])\n"
+ "st.w $w4, (1*16)(%[dst_ptr])\n"
+ "st.w $w8, (2*16)(%[dst_ptr])\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w1, (0*16)($a0)\n"
+ "st.w $w5, (1*16)($a0)\n"
+ "st.w $w9, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w2, (0*16)($a1)\n"
+ "st.w $w6, (1*16)($a1)\n"
+ "st.w $w10, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w3, (0*16)($a0)\n"
+ "st.w $w7, (1*16)($a0)\n"
+ "st.w $w11, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w12, (0*16)($a1)\n"
+ "st.w $w16, (1*16)($a1)\n"
+ "st.w $w20, (2*16)($a1)\n" GEMMLOWP_MIPS_XADDU
+ " $a1, $a0, %[dst_col_stride]\n"
+ "st.w $w13, (0*16)($a0)\n"
+ "st.w $w17, (1*16)($a0)\n"
+ "st.w $w21, (2*16)($a0)\n" GEMMLOWP_MIPS_XADDU
+ " $a0, $a1, %[dst_col_stride]\n"
+ "st.w $w14, (0*16)($a1)\n"
+ "st.w $w18, (1*16)($a1)\n"
+ "st.w $w22, (2*16)($a1)\n"
+ "st.w $w15, (0*16)($a0)\n"
+ "st.w $w19, (1*16)($a0)\n"
+ "st.w $w23, (2*16)($a0)\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [run_depth] "+r"(run_depth),
+ [dst_col_stride] "+r"(dst_col_stride)
+ : // inputs
+ [dst_ptr] "r"(dst_ptr),
+ [start_depth] "r"(start_depth)
+ : // clobbers
+ "memory", "v0", "v1", "a0", "a1", "a2", "a3", "t8", "t9", "$f0", "$f1",
+ "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9", "$f10", "$f11",
+ "$f12", "$f13", "$f14", "$f15", "$f16", "$f17", "$f18", "$f19", "$f20",
+ "$f21", "$f22", "$f23", "$f24", "$f25", "$f26", "$f27", "$f28", "$f29",
+ "$f30", "$f31");
+
+#undef GEMMLOWP_LABEL_CLEAR_ACCUMULATORS
+#undef GEMMLOWP_LABEL_BEFORE_LOOP
+#undef GEMMLOWP_LABEL_LOOP
+#undef GEMMLOWP_LABEL_AFTER_LOOP
+ }
+};
+
+#undef GEMMLOWP_MIPS_XADDU
+#undef GEMMLOWP_MIPS_XADDIU
+#undef GEMMLOWP_MIPS_XSLL
+
+#endif // GEMMLOWP_MSA
+
+} // namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_KERNEL_MSA_H_
diff --git a/internal/kernel_neon.h b/internal/kernel_neon.h
index 5c253ba..3cd48f4 100644
--- a/internal/kernel_neon.h
+++ b/internal/kernel_neon.h
@@ -421,52 +421,52 @@ struct NEON_32_Kernel12x4Depth2Assuming12BitProducts : KernelBase {
GEMMLOWP_LOOP_NEON_32_KERNEL_12X4_DEPTH2_ASSUMING_12BIT_PRODUCTS
":\n"
-// Overview of register layout:
-//
-// Registers q4--q16 are the local 16-bit accumulators.
-// However, each entry in the result matrix is represented
-// by *two* local 16-bit accumulators: one for even levels
-// of depth and one for odd levels of depth. These correspond
-// to the scalars at even and odd indices within each q-register.
-// Thus we effectively use 32 bits of register space for each
-// entry in the result matrix. The accumulators register layout
-// is the same as was described above for the global 32-bit
-// accumulators (3 cells of size 4x4 in diagonal-major order)
-// with the only difference that instead of 32bit values we have
-// pairs of 16bit values.
-//
-// A 2x4 cell of Rhs is stored in 8bit in d0.
-// A 12x2 block of 3 4x2 cells Lhs is stored in 8bit in d1--d3.
-//
-// +--------+--------+--------+--------+
-// |d0[0] |d0[2] |d0[4] |d0[6] |
-// Rhs +--------+--------+--------+--------+
-// |d0[1] |d0[3] |d0[5] |d0[7] |
-// +--------+--------+--------+--------+
-//
-// | | | | |
-//
-// Lhs | | | | |
-//
-// +-----+-----+ - - - +--------+--------+--------+--------+
-// |d1[0]|d1[1]| |q4[0,1] |q5[0,1] |q6[0,1] |q7[0,1] |
-// |d1[2]|d1[3]| |q7[2,3] |q4[2,3] |q5[2,3] |q6[2,3] |
-// |d1[4]|d1[5]| |q6[4,5] |q7[4,5] |q4[4,5] |q5[4,5] |
-// |d1[6]|d1[7]| |q5[6,7] |q6[6,7] |q7[6,7] |q4[6,7] |
-// +-----+-----+ - - - +--------+--------+--------+--------+
-// |d2[0]|d2[1]| |q8[0,1] |q8[0,1] |q8[0,1] |q8[0,1] |
-// |d2[2]|d2[3]| |q9[2,3] |q9[2,3] |q9[2,3] |q9[2,3] |
-// |d2[4]|d2[5]| |q10[4,5]|q10[4,5]|q10[4,5]|q10[4,5]|
-// |d2[6]|d2[7]| |q11[6,7]|q11[6,7]|q11[6,7]|q11[6,7]|
-// +-----+-----+ - - - +--------+--------+--------+--------+
-// |d3[0]|d3[1]| |q12[0,1]|q12[0,1]|q12[0,1]|q12[0,1]|
-// |d3[2]|d3[3]| |q13[2,3]|q13[2,3]|q13[2,3]|q13[2,3]|
-// |d3[4]|d3[5]| |q14[4,5]|q14[4,5]|q14[4,5]|q14[4,5]|
-// |d3[6]|d3[7]| |q15[6,7]|q15[6,7]|q15[6,7]|q15[6,7]|
-// +-----+-----+ - - - +--------+--------+--------+--------+
-//
-// Local 16-bit accumulators
-// Note: 2 scalars per matrix entry
+ // Overview of register layout:
+ //
+ // Registers q4--q16 are the local 16-bit accumulators.
+ // However, each entry in the result matrix is represented
+ // by *two* local 16-bit accumulators: one for even levels
+ // of depth and one for odd levels of depth. These correspond
+ // to the scalars at even and odd indices within each q-register.
+ // Thus we effectively use 32 bits of register space for each
+ // entry in the result matrix. The accumulators register layout
+ // is the same as was described above for the global 32-bit
+ // accumulators (3 cells of size 4x4 in diagonal-major order)
+ // with the only difference that instead of 32bit values we have
+ // pairs of 16bit values.
+ //
+ // A 2x4 cell of Rhs is stored in 8bit in d0.
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 8bit in d1--d3.
+ //
+ // +--------+--------+--------+--------+
+ // |d0[0] |d0[2] |d0[4] |d0[6] |
+ // Rhs +--------+--------+--------+--------+
+ // |d0[1] |d0[3] |d0[5] |d0[7] |
+ // +--------+--------+--------+--------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +-----+-----+ - - - +--------+--------+--------+--------+
+ // |d1[0]|d1[1]| |q4[0,1] |q5[0,1] |q6[0,1] |q7[0,1] |
+ // |d1[2]|d1[3]| |q7[2,3] |q4[2,3] |q5[2,3] |q6[2,3] |
+ // |d1[4]|d1[5]| |q6[4,5] |q7[4,5] |q4[4,5] |q5[4,5] |
+ // |d1[6]|d1[7]| |q5[6,7] |q6[6,7] |q7[6,7] |q4[6,7] |
+ // +-----+-----+ - - - +--------+--------+--------+--------+
+ // |d2[0]|d2[1]| |q8[0,1] |q8[0,1] |q8[0,1] |q8[0,1] |
+ // |d2[2]|d2[3]| |q9[2,3] |q9[2,3] |q9[2,3] |q9[2,3] |
+ // |d2[4]|d2[5]| |q10[4,5]|q10[4,5]|q10[4,5]|q10[4,5]|
+ // |d2[6]|d2[7]| |q11[6,7]|q11[6,7]|q11[6,7]|q11[6,7]|
+ // +-----+-----+ - - - +--------+--------+--------+--------+
+ // |d3[0]|d3[1]| |q12[0,1]|q12[0,1]|q12[0,1]|q12[0,1]|
+ // |d3[2]|d3[3]| |q13[2,3]|q13[2,3]|q13[2,3]|q13[2,3]|
+ // |d3[4]|d3[5]| |q14[4,5]|q14[4,5]|q14[4,5]|q14[4,5]|
+ // |d3[6]|d3[7]| |q15[6,7]|q15[6,7]|q15[6,7]|q15[6,7]|
+ // +-----+-----+ - - - +--------+--------+--------+--------+
+ //
+ // Local 16-bit accumulators
+ // Note: 2 scalars per matrix entry
#define GEMMLOWP_ACCUMULATE_2_LEVELS_OF_DEPTH \
/* Load 3 Lhs cells of size 4x2 */ \
@@ -1261,7 +1261,6 @@ struct NEON_64bit_GEMM_Int8Operands_LhsNonzero : KernelBase {
}
};
-
// Our main GEMM kernel.
struct NEON_64_Kernel12x8Depth2 : KernelBase {
typedef KernelFormat<KernelSideFormat<CellFormat<4, 2>, 3>,
diff --git a/internal/multi_thread_gemm.h b/internal/multi_thread_gemm.h
index df7387a..791402f 100644
--- a/internal/multi_thread_gemm.h
+++ b/internal/multi_thread_gemm.h
@@ -149,9 +149,7 @@ T WaitForVariableChange(volatile T* var, T initial_value, pthread_cond_t* cond,
// to have finished working.
class BlockingCounter {
public:
- BlockingCounter()
- : count_(0),
- initial_count_(0) {
+ BlockingCounter() : count_(0), initial_count_(0) {
pthread_cond_init(&cond_, nullptr);
pthread_mutex_init(&mutex_, nullptr);
}
@@ -548,11 +546,6 @@ class MultiThreadGemmContext : public MultiThreadGemmContextBase {
WorkersPool workers_pool_;
};
-// Needed by chrome native builds
-#ifndef _SC_NPROCESSORS_CONF
-#define _SC_NPROCESSORS_CONF _SC_NPROCESSORS_ONLN
-#endif
-
// Determines how many threads should be used for a given Gemm
// operation.
template <int KernelRows>
diff --git a/internal/output.h b/internal/output.h
index 8ccb8ee..dcfe2b5 100644
--- a/internal/output.h
+++ b/internal/output.h
@@ -119,12 +119,12 @@ struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>,
template <int Size>
struct OutputStageEvalBufferImpl<
- OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
+ OutputStageQuantizeDownInt32ByFixedPoint,
RegisterBuffer<std::int32_t, Size>> {
typedef RegisterBuffer<std::int32_t, Size> InputType;
typedef RegisterBuffer<std::int32_t, Size> OutputType;
- typedef OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint OutputStage;
+ typedef OutputStageQuantizeDownInt32ByFixedPoint OutputStage;
OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
@@ -146,6 +146,39 @@ struct OutputStageEvalBufferImpl<
const OutputStage& output_stage;
};
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageScaleInt32ByFixedPointAndExponent,
+ RegisterBuffer<std::int32_t, Size>> {
+ typedef RegisterBuffer<std::int32_t, Size> InputType;
+ typedef RegisterBuffer<std::int32_t, Size> OutputType;
+
+ typedef OutputStageScaleInt32ByFixedPointAndExponent OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {
+ left_shift = std::max(0, output_stage.result_exponent);
+ right_shift = std::max(0, -output_stage.result_exponent);
+ }
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ using RegisterType = typename InputType::RegisterType;
+ const RegisterType result_offset_after_shift =
+ Dup<RegisterType>(output_stage.result_offset_after_shift);
+ for (int i = 0; i < InputType::kRegisterCount; i++) {
+ const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul(
+ ShiftLeft(input.reg[i], left_shift),
+ output_stage.result_fixedpoint_multiplier);
+ output.reg[i] = Add(RoundingDivideByPOT(mulhigh_val, right_shift),
+ result_offset_after_shift);
+ }
+ return output;
+ }
+
+ const OutputStage& output_stage;
+ int left_shift;
+ int right_shift;
+};
+
// Implementation of OutputStageSaturatingCastToUint8 for scalar data
template <int Size>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
@@ -169,6 +202,29 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
}
};
+// Implementation of OutputStageSaturatingCastToInt16 for scalar data
+template <int Size>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegisterBuffer<std::int32_t, Size>> {
+ typedef RegisterBuffer<std::int32_t, Size> InputType;
+ typedef RegisterBuffer<std::int16_t, Size> OutputType;
+ static_assert(InputType::kRegisterLanes == 1,
+ "This path is only for scalar values");
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ for (int i = 0; i < InputType::kRegisterCount; i++) {
+ std::int32_t data = input.reg[i];
+ output.reg[i] = data > 32767 ? 32767 : data < -32768 ? -32768 : data;
+ }
+ return output;
+ }
+};
+
template <int Rows, int Cols, typename VectorType>
struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>,
RegisterBlock<std::int32_t, Rows, Cols>> {
@@ -430,6 +486,8 @@ struct OutputPipelineExecutor {
#include "output_neon.h"
#elif defined(GEMMLOWP_SSE4)
#include "output_sse.h"
+#elif defined(GEMMLOWP_MSA)
+#include "output_msa.h"
#endif
#endif // GEMMLOWP_INTERNAL_OUTPUT_H_
diff --git a/internal/output_msa.h b/internal/output_msa.h
new file mode 100644
index 0000000..4c8eb5d
--- /dev/null
+++ b/internal/output_msa.h
@@ -0,0 +1,622 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// output_msa.h: optimized MSA specializations of the templates in output.h.
+
+#ifndef GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
+#define GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
+
+#include "output.h"
+
+#include <msa.h>
+
+namespace gemmlowp {
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferUint8<4> OutputType;
+
+ typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Signed saturate each 32-bit element to 9 bits
+ // (this takes full care of non-negative elements).
+ v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
+ // Pack every 32-bit element into 16 bits.
+ tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
+ // Detect negative elements with arithmetic shift right (we
+ // get a 16-bit mask of all zeroes or all ones for every element).
+ v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15);
+ // Zero out negative elements.
+ signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
+ reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0));
+ // Pack every element into 8 bits.
+ tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
+ reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
+ // Return 4 uint8_t elements as uint32_t.
+ output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferUint8<8> OutputType;
+
+ typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Signed saturate each 32-bit element to 9 bits
+ // (this takes full care of non-negative elements).
+ v4i32 tmp_lo = __builtin_msa_sat_s_w(input.reg[0], 8);
+ v4i32 tmp_hi = __builtin_msa_sat_s_w(input.reg[1], 8);
+ // Pack every 32-bit element into 16 bits,
+ // combining all 8 elements into one vector.
+ tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
+ reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
+ // Detect negative elements with arithmetic shift right (we
+ // get a 16-bit mask of all zeroes or all ones for every element).
+ v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15);
+ // Zero out negative elements.
+ signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
+ reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0));
+ // Pack every element into 8 bits.
+ tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
+ reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
+ // Return 8 uint8_t elements as 2 uint32_t's.
+ output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
+ output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
+ return output;
+ }
+};
+
+#define GEMMLOWP_MIPS_SAT_U8_16(out, in0, in1, in2, in3) \
+ { \
+ v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 8); \
+ v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 8); \
+ v4i32 tmp2 = __builtin_msa_sat_s_w(in2, 8); \
+ v4i32 tmp3 = __builtin_msa_sat_s_w(in3, 8); \
+ tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( \
+ reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0))); \
+ tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h( \
+ reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2))); \
+ v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15); \
+ v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15); \
+ signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \
+ reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \
+ signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b( \
+ reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \
+ signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b( \
+ reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0))); \
+ out = reinterpret_cast<v16i8>(signs0); \
+ }
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferUint8<16> OutputType;
+
+ typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
+ input.reg[2], input.reg[3]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferUint8<32> OutputType;
+
+ typedef OutputStageSaturatingCastToUint8 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
+ input.reg[2], input.reg[3]);
+ GEMMLOWP_MIPS_SAT_U8_16(output.reg[1], input.reg[4], input.reg[5],
+ input.reg[6], input.reg[7]);
+ return output;
+ }
+};
+
+#undef GEMMLOWP_MIPS_SAT_U8_16
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferInt16<4> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ // Signed saturate each 32-bit element to 16 bits.
+ v8i16 tmp = reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(
+ input.reg[0], 15));
+ output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
+ output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
+ output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
+ output.reg[3] = __builtin_msa_copy_s_h(tmp, 6);
+ return output;
+ }
+};
+
+#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1) \
+ { \
+ v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15); \
+ v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15); \
+ out = __builtin_msa_pckev_h( \
+ reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)); \
+ }
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferInt16<8> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferInt16<16> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferInt16<32> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[2], input.reg[4], input.reg[5]);
+ GEMMLOWP_MIPS_SAT_I16_8(output.reg[3], input.reg[6], input.reg[7]);
+ return output;
+ }
+};
+
+#undef GEMMLOWP_MIPS_SAT_I16_8
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
+ static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
+ *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
+ *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
+ *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
+ static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+ StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
+ } else {
+ *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
+ *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
+ *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
+ *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
+ *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
+ *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
+ *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
+ *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
+ static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
+ int col) {
+ *dst->data(row + 0, col) = src.buf.reg[0];
+ *dst->data(row + 1, col) = src.buf.reg[1];
+ *dst->data(row + 2, col) = src.buf.reg[2];
+ *dst->data(row + 3, col) = src.buf.reg[3];
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
+ static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ *dst->data(row + 0, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 0);
+ *dst->data(row + 1, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 1);
+ *dst->data(row + 2, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 2);
+ *dst->data(row + 3, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 3);
+ *dst->data(row + 4, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 4);
+ *dst->data(row + 5, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 5);
+ *dst->data(row + 6, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 6);
+ *dst->data(row + 7, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 7);
+ }
+ }
+};
+
+inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
+ RegBlockInt32<4, 4> result;
+ v4i32 tmp0, tmp1;
+ tmp0 = __builtin_msa_ilvr_w(src.buf.reg[1], src.buf.reg[0]);
+ tmp1 = __builtin_msa_ilvr_w(src.buf.reg[3], src.buf.reg[2]);
+ result.buf.reg[0] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
+ result.buf.reg[1] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
+ tmp0 = __builtin_msa_ilvl_w(src.buf.reg[1], src.buf.reg[0]);
+ tmp1 = __builtin_msa_ilvl_w(src.buf.reg[3], src.buf.reg[2]);
+ result.buf.reg[2] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
+ result.buf.reg[3] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
+ return result;
+}
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
+ static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ const auto transpose = Transpose(src);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
+ static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
+ int col) {
+ std::int16_t buf[16];
+ StoreInt16x8(buf + 0, src.buf.reg[0]);
+ StoreInt16x8(buf + 8, src.buf.reg[1]);
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ *dst->data(row + i, col + j) = buf[i + 4 * j];
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
+ static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
+ StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
+ }
+ } else {
+ RegBlockInt32<4, 4> top;
+ top.buf.reg[0] = src.buf.reg[0];
+ top.buf.reg[1] = src.buf.reg[2];
+ top.buf.reg[2] = src.buf.reg[4];
+ top.buf.reg[3] = src.buf.reg[6];
+ const auto transpose_top = Transpose(top);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
+ }
+ RegBlockInt32<4, 4> bottom;
+ bottom.buf.reg[0] = src.buf.reg[1];
+ bottom.buf.reg[1] = src.buf.reg[3];
+ bottom.buf.reg[2] = src.buf.reg[5];
+ bottom.buf.reg[3] = src.buf.reg[7];
+ const auto transpose_bottom = Transpose(bottom);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
+ static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 4; i++) {
+ StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ std::int16_t buf[32];
+ StoreInt16x8(buf + 0, src.buf.reg[0]);
+ StoreInt16x8(buf + 8, src.buf.reg[1]);
+ StoreInt16x8(buf + 16, src.buf.reg[2]);
+ StoreInt16x8(buf + 24, src.buf.reg[3]);
+ for (int i = 0; i < 8; i++) {
+ for (int j = 0; j < 4; j++) {
+ *dst->data(row + i, col + j) = buf[i + 8 * j];
+ }
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
+ static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 8; i++) {
+ StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
+ StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
+ }
+ } else {
+ RegBlockInt32<4, 4> top_left;
+ top_left.buf.reg[0] = src.buf.reg[0];
+ top_left.buf.reg[1] = src.buf.reg[2];
+ top_left.buf.reg[2] = src.buf.reg[4];
+ top_left.buf.reg[3] = src.buf.reg[6];
+ const auto transpose_top_left = Transpose(top_left);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
+ }
+ RegBlockInt32<4, 4> bottom_left;
+ bottom_left.buf.reg[0] = src.buf.reg[1];
+ bottom_left.buf.reg[1] = src.buf.reg[3];
+ bottom_left.buf.reg[2] = src.buf.reg[5];
+ bottom_left.buf.reg[3] = src.buf.reg[7];
+ const auto transpose_bottom_left = Transpose(bottom_left);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + 4 + i, col),
+ transpose_bottom_left.buf.reg[i]);
+ }
+ RegBlockInt32<4, 4> top_right;
+ top_right.buf.reg[0] = src.buf.reg[8];
+ top_right.buf.reg[1] = src.buf.reg[10];
+ top_right.buf.reg[2] = src.buf.reg[12];
+ top_right.buf.reg[3] = src.buf.reg[14];
+ const auto transpose_top_right = Transpose(top_right);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + i, col + 4),
+ transpose_top_right.buf.reg[i]);
+ }
+ RegBlockInt32<4, 4> bottom_right;
+ bottom_right.buf.reg[0] = src.buf.reg[9];
+ bottom_right.buf.reg[1] = src.buf.reg[11];
+ bottom_right.buf.reg[2] = src.buf.reg[13];
+ bottom_right.buf.reg[3] = src.buf.reg[15];
+ const auto transpose_bottom_right = Transpose(bottom_right);
+ for (int i = 0; i < 4; i++) {
+ StoreInt32x4(dst->data(row + 4 + i, col + 4),
+ transpose_bottom_right.buf.reg[i]);
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
+ static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 8; i++) {
+ StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ // top-left 4x4
+ v4i32 t0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[1],
+ src.buf.reg[0]));
+ v4i32 t1 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[3],
+ src.buf.reg[2]));
+ v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
+ v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
+ // top-right 4x4
+ v4i32 t2 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[5],
+ src.buf.reg[4]));
+ v4i32 t3 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[7],
+ src.buf.reg[6]));
+ v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
+ v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
+ // bottom-left 4x4
+ v4i32 t4 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[1],
+ src.buf.reg[0]));
+ v4i32 t5 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[3],
+ src.buf.reg[2]));
+ v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
+ v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
+ // bottom-right 4x4
+ v4i32 t6 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[5],
+ src.buf.reg[4]));
+ v4i32 t7 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[7],
+ src.buf.reg[6]));
+ v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
+ v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));
+
+ StoreInt16x8(dst->data(row + 0, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u2, u0)));
+ StoreInt16x8(dst->data(row + 1, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u2, u0)));
+ StoreInt16x8(dst->data(row + 2, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u3, u1)));
+ StoreInt16x8(dst->data(row + 3, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u3, u1)));
+ StoreInt16x8(dst->data(row + 4, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u6, u4)));
+ StoreInt16x8(dst->data(row + 5, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u6, u4)));
+ StoreInt16x8(dst->data(row + 6, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvr_d(u7, u5)));
+ StoreInt16x8(dst->data(row + 7, col), reinterpret_cast<v8i16>(
+ __builtin_msa_ilvl_d(u7, u5)));
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
+ static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
+ *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
+ *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
+ *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
+ } else {
+ StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
+ static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
+ int col) {
+ const std::uint32_t src_reg = src.buf.reg[0];
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row + i, col) = (src_reg >> (8 * i));
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
+ static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
+ int col) {
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
+ }
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
+ static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
+ int col) {
+ for (int i = 0; i < 4; i++) {
+ *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
+ static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
+ int col) {
+ std::uint8_t buf[16];
+ StoreUint8x16(buf, src.buf.reg[0]);
+ for (int c = 0; c < 4; c++) {
+ for (int r = 0; r < 4; r++) {
+ *dst->data(row + r, col + c) = buf[r + 4 * c];
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
+ static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
+ int col) {
+ std::uint8_t buf[32];
+ StoreUint8x16(buf, src.buf.reg[0]);
+ StoreUint8x16(buf + 16, src.buf.reg[1]);
+ for (int c = 0; c < 4; c++) {
+ for (int r = 0; r < 8; r++) {
+ *dst->data(row + r, col + c) = buf[r + 8 * c];
+ }
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
+ static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
+ int col) {
+ std::uint8_t buf[64];
+ StoreUint8x16(buf, src.buf.reg[0]);
+ StoreUint8x16(buf + 16, src.buf.reg[1]);
+ StoreUint8x16(buf + 32, src.buf.reg[2]);
+ StoreUint8x16(buf + 48, src.buf.reg[3]);
+ for (int c = 0; c < 8; c++) {
+ for (int r = 0; r < 8; r++) {
+ *dst->data(row + r, col + c) = buf[r + 8 * c];
+ }
+ }
+ }
+};
+
+} // namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
diff --git a/internal/output_neon.h b/internal/output_neon.h
index 7e111e5..911fed0 100644
--- a/internal/output_neon.h
+++ b/internal/output_neon.h
@@ -107,6 +107,85 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
}
};
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferInt16<4> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] = vqmovn_s32(input.reg[0]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferInt16<8> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] =
+ vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferInt16<16> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] =
+ vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+ output.reg[1] =
+ vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferInt16<32> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] =
+ vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
+ output.reg[1] =
+ vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
+ output.reg[2] =
+ vcombine_s16(vqmovn_s32(input.reg[4]), vqmovn_s32(input.reg[5]));
+ output.reg[3] =
+ vcombine_s16(vqmovn_s32(input.reg[6]), vqmovn_s32(input.reg[7]));
+ return output;
+ }
+};
+
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
@@ -115,14 +194,48 @@ struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
} else {
- *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
- *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
- *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
- *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
- *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
- *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
- *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
- *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
+ vst1q_lane_s32(dst->data(row + 0, col), src.buf.reg[0], 0);
+ vst1q_lane_s32(dst->data(row + 1, col), src.buf.reg[0], 1);
+ vst1q_lane_s32(dst->data(row + 2, col), src.buf.reg[0], 2);
+ vst1q_lane_s32(dst->data(row + 3, col), src.buf.reg[0], 3);
+ vst1q_lane_s32(dst->data(row + 4, col), src.buf.reg[1], 0);
+ vst1q_lane_s32(dst->data(row + 5, col), src.buf.reg[1], 1);
+ vst1q_lane_s32(dst->data(row + 6, col), src.buf.reg[1], 2);
+ vst1q_lane_s32(dst->data(row + 7, col), src.buf.reg[1], 3);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
+ static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt16x4(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ vst1_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
+ vst1_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
+ vst1_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
+ vst1_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
+ }
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
+ static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ vst1q_lane_s16(dst->data(row + 0, col), src.buf.reg[0], 0);
+ vst1q_lane_s16(dst->data(row + 1, col), src.buf.reg[0], 1);
+ vst1q_lane_s16(dst->data(row + 2, col), src.buf.reg[0], 2);
+ vst1q_lane_s16(dst->data(row + 3, col), src.buf.reg[0], 3);
+ vst1q_lane_s16(dst->data(row + 4, col), src.buf.reg[0], 4);
+ vst1q_lane_s16(dst->data(row + 5, col), src.buf.reg[0], 5);
+ vst1q_lane_s16(dst->data(row + 6, col), src.buf.reg[0], 6);
+ vst1q_lane_s16(dst->data(row + 7, col), src.buf.reg[0], 7);
}
}
};
@@ -157,6 +270,35 @@ struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
+ static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ vst1_s16(dst->data(row, col + 0), vget_low_s16(src.buf.reg[0]));
+ vst1_s16(dst->data(row, col + 1), vget_high_s16(src.buf.reg[0]));
+ vst1_s16(dst->data(row, col + 2), vget_low_s16(src.buf.reg[1]));
+ vst1_s16(dst->data(row, col + 3), vget_high_s16(src.buf.reg[1]));
+ } else {
+ const int16x4x2_t t0 =
+ vtrn_s16(vget_low_s16(src.buf.reg[0]), vget_high_s16(src.buf.reg[0]));
+ const int16x4x2_t t1 =
+ vtrn_s16(vget_low_s16(src.buf.reg[1]), vget_high_s16(src.buf.reg[1]));
+ const int32x4x2_t t =
+ vtrnq_s32(vreinterpretq_s32_s16(vcombine_s16(t0.val[0], t0.val[1])),
+ vreinterpretq_s32_s16(vcombine_s16(t1.val[0], t1.val[1])));
+ vst1_s16(dst->data(row + 0, col),
+ vget_low_s16(vreinterpretq_s16_s32(t.val[0])));
+ vst1_s16(dst->data(row + 1, col),
+ vget_high_s16(vreinterpretq_s16_s32(t.val[0])));
+ vst1_s16(dst->data(row + 2, col),
+ vget_low_s16(vreinterpretq_s16_s32(t.val[1])));
+ vst1_s16(dst->data(row + 3, col),
+ vget_high_s16(vreinterpretq_s16_s32(t.val[1])));
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
int col) {
@@ -192,6 +334,42 @@ struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
+ static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
+ vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
+ vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
+ vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
+ } else {
+ const int16x8x2_t t0 = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
+ const int16x8x2_t t1 = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
+ const int32x4x2_t u0 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[0]),
+ vreinterpretq_s32_s16(t1.val[0]));
+ const int32x4x2_t u1 = vtrnq_s32(vreinterpretq_s32_s16(t0.val[1]),
+ vreinterpretq_s32_s16(t1.val[1]));
+ vst1_s16(dst->data(row + 0, col),
+ vget_low_s16(vreinterpretq_s16_s32(u0.val[0])));
+ vst1_s16(dst->data(row + 1, col),
+ vget_low_s16(vreinterpretq_s16_s32(u1.val[0])));
+ vst1_s16(dst->data(row + 2, col),
+ vget_low_s16(vreinterpretq_s16_s32(u0.val[1])));
+ vst1_s16(dst->data(row + 3, col),
+ vget_low_s16(vreinterpretq_s16_s32(u1.val[1])));
+ vst1_s16(dst->data(row + 4, col),
+ vget_high_s16(vreinterpretq_s16_s32(u0.val[0])));
+ vst1_s16(dst->data(row + 5, col),
+ vget_high_s16(vreinterpretq_s16_s32(u1.val[0])));
+ vst1_s16(dst->data(row + 6, col),
+ vget_high_s16(vreinterpretq_s16_s32(u0.val[1])));
+ vst1_s16(dst->data(row + 7, col),
+ vget_high_s16(vreinterpretq_s16_s32(u1.val[1])));
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
int col) {
@@ -281,6 +459,23 @@ struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<1, 4>, DstType> {
+ static void Run(const RegBlockInt16<1, 4>& src, DstType* dst, int row,
+ int col) {
+ std::int16_t* dst_ptr = dst->data(row, col);
+ if (DstType::kOrder == MapOrder::RowMajor) {
+ vst1_s16(dst_ptr, src.buf.reg[0]);
+ } else {
+ int col_stride = dst->cols_stride();
+ vst1_lane_s16(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
+ vst1_lane_s16(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
+ vst1_lane_s16(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
+ vst1_lane_s16(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
int col) {
@@ -427,6 +622,70 @@ struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
}
};
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
+ static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ vst1q_s16(dst->data(row, col + 0), src.buf.reg[0]);
+ vst1q_s16(dst->data(row, col + 1), src.buf.reg[1]);
+ vst1q_s16(dst->data(row, col + 2), src.buf.reg[2]);
+ vst1q_s16(dst->data(row, col + 3), src.buf.reg[3]);
+ vst1q_s16(dst->data(row, col + 4), src.buf.reg[4]);
+ vst1q_s16(dst->data(row, col + 5), src.buf.reg[5]);
+ vst1q_s16(dst->data(row, col + 6), src.buf.reg[6]);
+ vst1q_s16(dst->data(row, col + 7), src.buf.reg[7]);
+ } else {
+ int16x8x2_t a[4];
+ a[0] = vtrnq_s16(src.buf.reg[0], src.buf.reg[1]);
+ a[1] = vtrnq_s16(src.buf.reg[2], src.buf.reg[3]);
+ a[2] = vtrnq_s16(src.buf.reg[4], src.buf.reg[5]);
+ a[3] = vtrnq_s16(src.buf.reg[6], src.buf.reg[7]);
+ int32x4x2_t b[4];
+ b[0] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[0]),
+ vreinterpretq_s32_s16(a[1].val[0]));
+ b[1] = vtrnq_s32(vreinterpretq_s32_s16(a[0].val[1]),
+ vreinterpretq_s32_s16(a[1].val[1]));
+ b[2] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[0]),
+ vreinterpretq_s32_s16(a[3].val[0]));
+ b[3] = vtrnq_s32(vreinterpretq_s32_s16(a[2].val[1]),
+ vreinterpretq_s32_s16(a[3].val[1]));
+ vst1_s16(dst->data(row + 0, col + 0),
+ vget_low_s16(vreinterpretq_s16_s32(b[0].val[0])));
+ vst1_s16(dst->data(row + 0, col + 4),
+ vget_low_s16(vreinterpretq_s16_s32(b[2].val[0])));
+ vst1_s16(dst->data(row + 1, col + 0),
+ vget_low_s16(vreinterpretq_s16_s32(b[1].val[0])));
+ vst1_s16(dst->data(row + 1, col + 4),
+ vget_low_s16(vreinterpretq_s16_s32(b[3].val[0])));
+ vst1_s16(dst->data(row + 2, col + 0),
+ vget_low_s16(vreinterpretq_s16_s32(b[0].val[1])));
+ vst1_s16(dst->data(row + 2, col + 4),
+ vget_low_s16(vreinterpretq_s16_s32(b[2].val[1])));
+ vst1_s16(dst->data(row + 3, col + 0),
+ vget_low_s16(vreinterpretq_s16_s32(b[1].val[1])));
+ vst1_s16(dst->data(row + 3, col + 4),
+ vget_low_s16(vreinterpretq_s16_s32(b[3].val[1])));
+ vst1_s16(dst->data(row + 4, col + 0),
+ vget_high_s16(vreinterpretq_s16_s32(b[0].val[0])));
+ vst1_s16(dst->data(row + 4, col + 4),
+ vget_high_s16(vreinterpretq_s16_s32(b[2].val[0])));
+ vst1_s16(dst->data(row + 5, col + 0),
+ vget_high_s16(vreinterpretq_s16_s32(b[1].val[0])));
+ vst1_s16(dst->data(row + 5, col + 4),
+ vget_high_s16(vreinterpretq_s16_s32(b[3].val[0])));
+ vst1_s16(dst->data(row + 6, col + 0),
+ vget_high_s16(vreinterpretq_s16_s32(b[0].val[1])));
+ vst1_s16(dst->data(row + 6, col + 4),
+ vget_high_s16(vreinterpretq_s16_s32(b[2].val[1])));
+ vst1_s16(dst->data(row + 7, col + 0),
+ vget_high_s16(vreinterpretq_s16_s32(b[1].val[1])));
+ vst1_s16(dst->data(row + 7, col + 4),
+ vget_high_s16(vreinterpretq_s16_s32(b[3].val[1])));
+ }
+ }
+};
+
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
diff --git a/internal/output_sse.h b/internal/output_sse.h
index 5c06253..75aebfd 100644
--- a/internal/output_sse.h
+++ b/internal/output_sse.h
@@ -103,6 +103,82 @@ struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
}
};
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<4>> {
+ typedef RegBufferInt32<4> InputType;
+ typedef RegBufferInt16<4> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]);
+ output.reg[0] = _mm_extract_epi16(res_16, 0);
+ output.reg[1] = _mm_extract_epi16(res_16, 1);
+ output.reg[2] = _mm_extract_epi16(res_16, 2);
+ output.reg[3] = _mm_extract_epi16(res_16, 3);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<8>> {
+ typedef RegBufferInt32<8> InputType;
+ typedef RegBufferInt16<8> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<16>> {
+ typedef RegBufferInt32<16> InputType;
+ typedef RegBufferInt16<16> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
+ output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]);
+ return output;
+ }
+};
+
+template <>
+struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
+ RegBufferInt32<32>> {
+ typedef RegBufferInt32<32> InputType;
+ typedef RegBufferInt16<32> OutputType;
+
+ typedef OutputStageSaturatingCastToInt16 OutputStage;
+
+ OutputStageEvalBufferImpl(const OutputStage&) {}
+
+ OutputType Eval(InputType input) const {
+ OutputType output;
+ output.reg[0] = _mm_packs_epi32(input.reg[0], input.reg[1]);
+ output.reg[1] = _mm_packs_epi32(input.reg[2], input.reg[3]);
+ output.reg[2] = _mm_packs_epi32(input.reg[4], input.reg[5]);
+ output.reg[3] = _mm_packs_epi32(input.reg[6], input.reg[7]);
+ return output;
+ }
+};
+
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
@@ -138,6 +214,36 @@ struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
}
};
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
+ static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
+ int col) {
+ *dst->data(row + 0, col) = src.buf.reg[0];
+ *dst->data(row + 1, col) = src.buf.reg[1];
+ *dst->data(row + 2, col) = src.buf.reg[2];
+ *dst->data(row + 3, col) = src.buf.reg[3];
+ }
+};
+
+template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
+ static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
+ } else {
+ *dst->data(row + 0, col) = _mm_extract_epi16(src.buf.reg[0], 0);
+ *dst->data(row + 1, col) = _mm_extract_epi16(src.buf.reg[0], 1);
+ *dst->data(row + 2, col) = _mm_extract_epi16(src.buf.reg[0], 2);
+ *dst->data(row + 3, col) = _mm_extract_epi16(src.buf.reg[0], 3);
+ *dst->data(row + 4, col) = _mm_extract_epi16(src.buf.reg[0], 4);
+ *dst->data(row + 5, col) = _mm_extract_epi16(src.buf.reg[0], 5);
+ *dst->data(row + 6, col) = _mm_extract_epi16(src.buf.reg[0], 6);
+ *dst->data(row + 7, col) = _mm_extract_epi16(src.buf.reg[0], 7);
+ }
+ }
+};
+
inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
__m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]);
__m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]);
@@ -170,6 +276,21 @@ struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
+ static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
+ int col) {
+ std::int16_t buf[16];
+ StoreInt16x8(buf + 0, src.buf.reg[0]);
+ StoreInt16x8(buf + 8, src.buf.reg[1]);
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ *dst->data(row + i, col + j) = buf[i + 4 * j];
+ }
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
int col) {
@@ -202,6 +323,29 @@ struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
+ static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 4; i++) {
+ StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ std::int16_t buf[32];
+ StoreInt16x8(buf + 0, src.buf.reg[0]);
+ StoreInt16x8(buf + 8, src.buf.reg[1]);
+ StoreInt16x8(buf + 16, src.buf.reg[2]);
+ StoreInt16x8(buf + 24, src.buf.reg[3]);
+ for (int i = 0; i < 8; i++) {
+ for (int j = 0; j < 4; j++) {
+ *dst->data(row + i, col + j) = buf[i + 8 * j];
+ }
+ }
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
int col) {
@@ -255,6 +399,48 @@ struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
};
template <typename DstType>
+struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
+ static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
+ int col) {
+ if (DstType::kOrder == MapOrder::ColMajor) {
+ for (int i = 0; i < 8; i++) {
+ StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
+ }
+ } else {
+ // top-left 4x4
+ __m128i t0 = _mm_unpacklo_epi16(src.buf.reg[0], src.buf.reg[1]);
+ __m128i t1 = _mm_unpacklo_epi16(src.buf.reg[2], src.buf.reg[3]);
+ __m128i u0 = _mm_unpacklo_epi32(t0, t1);
+ __m128i u1 = _mm_unpackhi_epi32(t0, t1);
+ // top-right 4x4
+ __m128i t2 = _mm_unpacklo_epi16(src.buf.reg[4], src.buf.reg[5]);
+ __m128i t3 = _mm_unpacklo_epi16(src.buf.reg[6], src.buf.reg[7]);
+ __m128i u2 = _mm_unpacklo_epi32(t2, t3);
+ __m128i u3 = _mm_unpackhi_epi32(t2, t3);
+ // bottom-left 4x4
+ __m128i t4 = _mm_unpackhi_epi16(src.buf.reg[0], src.buf.reg[1]);
+ __m128i t5 = _mm_unpackhi_epi16(src.buf.reg[2], src.buf.reg[3]);
+ __m128i u4 = _mm_unpacklo_epi32(t4, t5);
+ __m128i u5 = _mm_unpackhi_epi32(t4, t5);
+ // bottom-right 4x4
+ __m128i t6 = _mm_unpackhi_epi16(src.buf.reg[4], src.buf.reg[5]);
+ __m128i t7 = _mm_unpackhi_epi16(src.buf.reg[6], src.buf.reg[7]);
+ __m128i u6 = _mm_unpacklo_epi32(t6, t7);
+ __m128i u7 = _mm_unpackhi_epi32(t6, t7);
+
+ StoreInt16x8(dst->data(row + 0, col), _mm_unpacklo_epi64(u0, u2));
+ StoreInt16x8(dst->data(row + 1, col), _mm_unpackhi_epi64(u0, u2));
+ StoreInt16x8(dst->data(row + 2, col), _mm_unpacklo_epi64(u1, u3));
+ StoreInt16x8(dst->data(row + 3, col), _mm_unpackhi_epi64(u1, u3));
+ StoreInt16x8(dst->data(row + 4, col), _mm_unpacklo_epi64(u4, u6));
+ StoreInt16x8(dst->data(row + 5, col), _mm_unpackhi_epi64(u4, u6));
+ StoreInt16x8(dst->data(row + 6, col), _mm_unpacklo_epi64(u5, u7));
+ StoreInt16x8(dst->data(row + 7, col), _mm_unpackhi_epi64(u5, u7));
+ }
+ }
+};
+
+template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
int col) {
diff --git a/internal/pack.h b/internal/pack.h
index 3395396..cb4b93a 100644
--- a/internal/pack.h
+++ b/internal/pack.h
@@ -430,6 +430,8 @@ void PackRhs(PackedSideBlock* dst, const MatrixMapType& src) {
#include "pack_neon.h"
#elif defined(GEMMLOWP_SSE4)
#include "pack_sse.h"
+#elif defined(GEMMLOWP_MSA)
+#include "pack_msa.h"
#endif
#endif // GEMMLOWP_INTERNAL_PACK_H_
diff --git a/internal/pack_msa.h b/internal/pack_msa.h
new file mode 100644
index 0000000..fba8a0f
--- /dev/null
+++ b/internal/pack_msa.h
@@ -0,0 +1,353 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// pack_msa.h: optimized MSA specializations of the templates in pack.h.
+
+#ifndef GEMMLOWP_INTERNAL_PACK_MSA_H_
+#define GEMMLOWP_INTERNAL_PACK_MSA_H_
+
+#include "pack.h"
+
+#include <msa.h>
+
+namespace gemmlowp {
+
+typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
+ WidthMajorUint8SideMap;
+
+template <int Cells>
+using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;
+
+template <int Cells>
+class PackingRegisterBlock<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>>
+ : public PackingRegisterBlockBase<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> {
+ public:
+ typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
+ typedef typename KernelSideFormat::Cell CellFormat;
+ static constexpr int kCells = KernelSideFormat::kCells;
+ static const int kCellWidth = CellFormat::kWidth;
+ static const int kKernelWidth = CellFormat::kWidth * kCells;
+ static const int kCellDepth = CellFormat::kDepth;
+ static const int kCellSize = CellFormat::kSize;
+
+ void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+ std::uint8_t* dst_ptr = dst->current_data();
+ const std::uint8_t* const src_ptr = this->complete_src_.data();
+ const int stride = this->complete_src_.stride();
+ // Load source WidthMajor data
+ v16i8 src_lines[4 * kCells];
+ for (int i = 0; i < 4 * kCells; i++) {
+ src_lines[i] = __builtin_msa_ld_b(
+ const_cast<std::uint8_t*>(src_ptr + i * stride), 0);
+ }
+ // Reorder the data within registers to make DepthMajor 4x2 cells
+ v16i8 src_lines_intertwined_2x[2 * kCells][2];
+ for (int i = 0; i < kCells; i++) {
+ src_lines_intertwined_2x[2 * i][0] =
+ __builtin_msa_ilvr_b(src_lines[4 * i + 2], src_lines[4 * i]);
+ src_lines_intertwined_2x[2 * i][1] =
+ __builtin_msa_ilvl_b(src_lines[4 * i + 2], src_lines[4 * i]);
+ src_lines_intertwined_2x[2 * i + 1][0] =
+ __builtin_msa_ilvr_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
+ src_lines_intertwined_2x[2 * i + 1][1] =
+ __builtin_msa_ilvl_b(src_lines[4 * i + 3], src_lines[4 * i + 1]);
+ }
+ v16i8 src_lines_intertwined_4x[2 * kCells][2];
+ for (int i = 0; i < kCells; i++) {
+ src_lines_intertwined_4x[2 * i][0] =
+ __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][0],
+ src_lines_intertwined_2x[2 * i][0]);
+ src_lines_intertwined_4x[2 * i][1] =
+ __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][0],
+ src_lines_intertwined_2x[2 * i][0]);
+ src_lines_intertwined_4x[2 * i + 1][0] =
+ __builtin_msa_ilvr_b(src_lines_intertwined_2x[2 * i + 1][1],
+ src_lines_intertwined_2x[2 * i][1]);
+ src_lines_intertwined_4x[2 * i + 1][1] =
+ __builtin_msa_ilvl_b(src_lines_intertwined_2x[2 * i + 1][1],
+ src_lines_intertwined_2x[2 * i][1]);
+ }
+ // Store the resulting DepthMajor 4x2 cells in the destination packed block
+ for (int outer = 0; outer < 2; outer++) {
+ for (int inner = 0; inner < 2; inner++) {
+ if (kCells % 2 == 0) {
+ for (int cell = 0; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ for (int cell = 0; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ } else {
+ // Store even number of low vector halves.
+ for (int cell = 0; cell < kCells - 1; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ // Store last low half and first high half.
+ v2i64 tmp = reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * 0 + outer][inner]);
+ tmp = __builtin_msa_insve_d(
+ tmp, 0,
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ // Store even number of high vector halves.
+ for (int cell = 1; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ }
+ }
+ }
+ // Compute sums across the depth dimension
+ v8i16 sums_of_2_cells[kCells][4];
+ const v16i8 zeroes = __builtin_msa_ldi_b(0);
+ for (int outer = 0; outer < 2; outer++) {
+ for (int inner = 0; inner < 2; inner++) {
+ int i = 2 * outer + inner;
+ for (int cell = 0; cell < kCells; cell++) {
+ v8i16 tmp0 = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(
+ zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
+ v8i16 tmp1 = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(
+ zeroes, src_lines_intertwined_4x[2 * cell + outer][inner]));
+ sums_of_2_cells[cell][i] = __builtin_msa_addv_h(tmp0, tmp1);
+ }
+ }
+ }
+ v4i32 sums_of_4_cells[kCells][4];
+ for (int i = 0; i < 4; i++) {
+ for (int cell = 0; cell < kCells; cell++) {
+ v4i32 tmp0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(
+ reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
+ v4i32 tmp1 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(
+ reinterpret_cast<v8i16>(zeroes), sums_of_2_cells[cell][i]));
+ sums_of_4_cells[cell][i] = __builtin_msa_addv_w(tmp0, tmp1);
+ }
+ }
+ // Update the sums_of_each_slice vector
+ for (int cell = 0; cell < kCells; cell++) {
+ v4i32 s01 = __builtin_msa_addv_w(sums_of_4_cells[cell][0],
+ sums_of_4_cells[cell][1]);
+ v4i32 s23 = __builtin_msa_addv_w(sums_of_4_cells[cell][2],
+ sums_of_4_cells[cell][3]);
+ v4i32 s = __builtin_msa_addv_w(s01, s23);
+ std::int32_t* sums_of_each_slice_ptr =
+ dst->sums_of_each_slice() + start_width + 4 * cell;
+ v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
+ tmp = __builtin_msa_addv_w(tmp, s);
+ __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
+ }
+ dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
+ }
+};
+
+template <int Cells>
+using WidthMajorSideFormatNCells4x2 =
+ KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
+
+template <int Cells>
+class PackingRegisterBlock<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
+ : public PackingRegisterBlockBase<
+ WidthMajorUint8SideMap,
+ PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
+ public:
+ typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
+ typedef typename KernelSideFormat::Cell CellFormat;
+ static constexpr int kCells = KernelSideFormat::kCells;
+ static const int kCellWidth = CellFormat::kWidth;
+ static const int kKernelWidth = CellFormat::kWidth * kCells;
+ static const int kCellDepth = CellFormat::kDepth;
+ static const int kCellSize = CellFormat::kSize;
+
+ void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
+ std::uint8_t* dst_ptr = dst->current_data();
+ const std::uint8_t* src_ptr = this->complete_src_.data();
+ const int stride = this->complete_src_.stride();
+ // Load source WidthMajor data
+ v8i16 src_lines[kCells * 4];
+ for (int i = 0; i < kCells; i++) {
+#define GEMMLOWP_UNROLLED_LOOP_ITER(k) \
+ src_lines[4 * i + k] = \
+ __builtin_msa_ld_h(const_cast<std::uint8_t*>(src_ptr), 0); \
+ src_ptr += stride;
+
+ GEMMLOWP_UNROLLED_LOOP_ITER(0)
+ GEMMLOWP_UNROLLED_LOOP_ITER(1)
+ GEMMLOWP_UNROLLED_LOOP_ITER(2)
+ GEMMLOWP_UNROLLED_LOOP_ITER(3)
+
+#undef GEMMLOWP_UNROLLED_LOOP_ITER
+ }
+ // Reorder the data within registers to make WidthMajor 4x2 cells
+ v8i16 src_lines_intertwined_2x[2 * kCells][2];
+ for (int i = 0; i < kCells; i++) {
+ src_lines_intertwined_2x[2 * i][0] =
+ __builtin_msa_ilvr_h(src_lines[4 * i + 2], src_lines[4 * i]);
+ src_lines_intertwined_2x[2 * i][1] =
+ __builtin_msa_ilvl_h(src_lines[4 * i + 2], src_lines[4 * i]);
+ src_lines_intertwined_2x[2 * i + 1][0] =
+ __builtin_msa_ilvr_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
+ src_lines_intertwined_2x[2 * i + 1][1] =
+ __builtin_msa_ilvl_h(src_lines[4 * i + 3], src_lines[4 * i + 1]);
+ }
+ v8i16 src_lines_intertwined_4x[2 * kCells][2];
+ for (int i = 0; i < kCells; i++) {
+ src_lines_intertwined_4x[2 * i][0] =
+ __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][0],
+ src_lines_intertwined_2x[2 * i][0]);
+ src_lines_intertwined_4x[2 * i][1] =
+ __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][0],
+ src_lines_intertwined_2x[2 * i][0]);
+ src_lines_intertwined_4x[2 * i + 1][0] =
+ __builtin_msa_ilvr_h(src_lines_intertwined_2x[2 * i + 1][1],
+ src_lines_intertwined_2x[2 * i][1]);
+ src_lines_intertwined_4x[2 * i + 1][1] =
+ __builtin_msa_ilvl_h(src_lines_intertwined_2x[2 * i + 1][1],
+ src_lines_intertwined_2x[2 * i][1]);
+ }
+ // Store the resulting WidthMajor 4x2 cells in the destination packed block
+ for (int outer = 0; outer < 2; outer++) {
+ for (int inner = 0; inner < 2; inner++) {
+ if (kCells % 2 == 0) {
+ for (int cell = 0; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ for (int cell = 0; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ } else {
+ // Store even number of low vector halves.
+ for (int cell = 0; cell < kCells - 1; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvr_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ // Store last low half and first high half.
+ v2i64 tmp = reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * 0 + outer][inner]);
+ tmp = __builtin_msa_insve_d(
+ tmp, 0,
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (kCells - 1) + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ // Store even number of high vector halves.
+ for (int cell = 1; cell < kCells; cell += 2) {
+ v2i64 tmp = __builtin_msa_ilvl_d(
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * (cell + 1) + outer][inner]),
+ reinterpret_cast<v2i64>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]));
+ __builtin_msa_st_b(reinterpret_cast<v16i8>(tmp), dst_ptr, 0);
+ dst_ptr += 16;
+ }
+ }
+ }
+ }
+ // Compute sums across the depth dimension
+ v8i16 sums_of_2[kCells][4];
+ for (int outer = 0; outer < 2; outer++) {
+ for (int inner = 0; inner < 2; inner++) {
+ int i = 2 * outer + inner;
+ for (int cell = 0; cell < kCells; cell++) {
+ sums_of_2[cell][i] = reinterpret_cast<v8i16>(__builtin_msa_hadd_u_h(
+ reinterpret_cast<v16u8>(
+ src_lines_intertwined_4x[2 * cell + outer][inner]),
+ reinterpret_cast<v16u8>(
+ src_lines_intertwined_4x[2 * cell + outer][inner])));
+ }
+ }
+ }
+ v8i16 sums_of_4[kCells][2];
+ for (int i = 0; i < 2; i++) {
+ for (int cell = 0; cell < kCells; cell++) {
+ sums_of_4[cell][i] = __builtin_msa_addv_h(sums_of_2[cell][2 * i],
+ sums_of_2[cell][2 * i + 1]);
+ }
+ }
+ v8i16 sums_of_8[kCells];
+ for (int cell = 0; cell < kCells; cell++) {
+ sums_of_8[cell] =
+ __builtin_msa_addv_h(sums_of_4[cell][0], sums_of_4[cell][1]);
+ }
+
+ v4i32 sums_of_16[kCells];
+ const v8i16 zeroes = __builtin_msa_ldi_h(0);
+ for (int cell = 0; cell < kCells; cell++) {
+ sums_of_16[cell] = reinterpret_cast<v4i32>(
+ __builtin_msa_ilvr_h(zeroes, sums_of_8[cell]));
+ v8i16 tmp = __builtin_msa_ilvl_h(zeroes, sums_of_8[cell]);
+ sums_of_16[cell] =
+ __builtin_msa_addv_w(sums_of_16[cell], reinterpret_cast<v4i32>(tmp));
+ }
+ // Update the sums_of_each_slice vector
+ for (int cell = 0; cell < kCells; cell++) {
+ std::int32_t* sums_of_each_slice_ptr =
+ dst->sums_of_each_slice() + start_width + 4 * cell;
+ v4i32 tmp = __builtin_msa_ld_w(sums_of_each_slice_ptr, 0);
+ tmp = __builtin_msa_addv_w(tmp, sums_of_16[cell]);
+ __builtin_msa_st_w(tmp, sums_of_each_slice_ptr, 0);
+ }
+ dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
+ }
+};
+
+} // namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_PACK_MSA_H_
diff --git a/internal/pack_neon.h b/internal/pack_neon.h
index e212d07..2b08464 100644
--- a/internal/pack_neon.h
+++ b/internal/pack_neon.h
@@ -153,10 +153,10 @@ class PackingRegisterBlock<
// Load source WidthMajor data
uint16x8_t src_lines[kCells * 4];
for (int i = 0; i < kCells; i++) {
-// This packing path is used with our current
-// less-than-8-bit kernel, and the partial unrolling of this loop
-// results in substantially faster code (thanks to better
-// register allocation) on Nexus 5.
+ // This packing path is used with our current
+ // less-than-8-bit kernel, and the partial unrolling of this loop
+ // results in substantially faster code (thanks to better
+ // register allocation) on Nexus 5.
#define GEMMLOWP_UNROLLED_LOOP_ITER(k) \
src_lines[4 * i + k] = vreinterpretq_u16_u8(vld1q_u8(src_ptr)); \
diff --git a/internal/platform.h b/internal/platform.h
index 49e41a9..1114767 100755..100644
--- a/internal/platform.h
+++ b/internal/platform.h
@@ -17,17 +17,20 @@
#ifndef GEMMLOWP_INTERNAL_PLATFORM_H_
#define GEMMLOWP_INTERNAL_PLATFORM_H_
-
#ifdef _WIN32
#include <windows.h>
#else
-#include <unistd.h>
-#include <time.h>
#include <stdlib.h>
+#include <time.h>
+#include <unistd.h>
+#endif
+
+#ifdef __APPLE__
+#include <sys/time.h>
#endif
-#include <malloc.h>
#if defined ANDROID || defined __ANDROID__
+#include <malloc.h>
#include <android/api-level.h>
// The 18 here should be 16, but has to be 18 for now due
// to a Google-internal issue.
@@ -42,6 +45,10 @@
#endif
#endif
+// Needed by chrome native builds
+#ifndef _SC_NPROCESSORS_CONF
+#define _SC_NPROCESSORS_CONF _SC_NPROCESSORS_ONLN
+#endif
namespace gemmlowp {
@@ -50,9 +57,7 @@ inline void *aligned_alloc(size_t alignment, size_t size) {
return _aligned_malloc(size, alignment);
}
-inline void aligned_free(void *memptr) {
- _aligned_free(memptr);
-}
+inline void aligned_free(void *memptr) { _aligned_free(memptr); }
inline int GetHardwareConcurrency(int max_threads) {
if (max_threads == 0) {
@@ -64,8 +69,9 @@ inline int GetHardwareConcurrency(int max_threads) {
}
inline double real_time_in_seconds() {
- __int64 wintime; GetSystemTimeAsFileTime((FILETIME*)&wintime);
- wintime -= 116444736000000000i64; //1jan1601 to 1jan1970
+ __int64 wintime;
+ GetSystemTimeAsFileTime((FILETIME *)&wintime);
+ wintime -= 116444736000000000i64; // 1jan1601 to 1jan1970
return wintime / 10000000i64 + wintime % 10000000i64 * 100 * 1e-9;
}
@@ -91,9 +97,7 @@ inline int GetHardwareConcurrency(int max_threads) {
return max_threads;
}
-inline void aligned_free(void *memptr) {
- free(memptr);
-}
+inline void aligned_free(void *memptr) { free(memptr); }
inline double real_time_in_seconds() {
#ifdef __APPLE__
@@ -108,5 +112,5 @@ inline double real_time_in_seconds() {
}
#endif
-} // namespace gemmlowp
+} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_PLATFORM_H_
diff --git a/internal/simd_wrappers.h b/internal/simd_wrappers.h
index e39eaf8..d9721c9 100644
--- a/internal/simd_wrappers.h
+++ b/internal/simd_wrappers.h
@@ -491,10 +491,14 @@ void AddConstant(RegisterBlockType* block) {
template <int N>
using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
template <int N>
+using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
+template <int N>
using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
template <int R, int C>
using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
template <int R, int C>
+using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
+template <int R, int C>
using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
} // end namespace gemmlowp
@@ -503,6 +507,8 @@ using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
#include "simd_wrappers_neon.h"
#elif defined GEMMLOWP_SSE4
#include "simd_wrappers_sse.h"
+#elif defined GEMMLOWP_MSA
+#include "simd_wrappers_msa.h"
#endif
#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
diff --git a/internal/simd_wrappers_msa.h b/internal/simd_wrappers_msa.h
new file mode 100644
index 0000000..cf5e8e9
--- /dev/null
+++ b/internal/simd_wrappers_msa.h
@@ -0,0 +1,196 @@
+// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// simd_wrappers_msa.h: MSA specialization of simd_wrappers.h
+
+#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_
+#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_
+
+#include <msa.h>
+
+namespace gemmlowp {
+
+using Int32x4 = v4i32;
+using Int16x8 = v8i16;
+using Uint8x16 = v16i8;
+
+template <int ScalarCount>
+struct RegisterType<std::int32_t, ScalarCount> {
+ using Type =
+ typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
+};
+
+template <int ScalarCount>
+struct RegisterType<std::int16_t, ScalarCount> {
+ using Type =
+ typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
+};
+
+template <int ScalarCount>
+struct RegisterType<std::uint8_t, ScalarCount> {
+ using Type = typename std::conditional<
+ ScalarCount >= 16, Uint8x16,
+ typename std::conditional<ScalarCount >= 4, std::uint32_t,
+ std::uint8_t>::type>::type;
+};
+
+inline Int32x4 LoadInt32x4(const std::int32_t* src) {
+ return __builtin_msa_ld_w(const_cast<std::int32_t*>(src), 0);
+}
+
+inline Int32x4 LoadInt32x4(const Int32x4* src) {
+ return __builtin_msa_ld_w(const_cast<Int32x4*>(src), 0);
+}
+
+inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
+ __builtin_msa_st_w(value, dst, 0);
+}
+
+inline void StoreInt32x4(Int32x4* dst, Int32x4 value) {
+ __builtin_msa_st_w(value, dst, 0);
+}
+
+inline Int16x8 LoadInt16x8(const std::int16_t* src) {
+ return __builtin_msa_ld_h(const_cast<std::int16_t*>(src), 0);
+}
+
+inline Int16x8 LoadInt16x8(const Int16x8* src) {
+ return __builtin_msa_ld_h(const_cast<Int16x8*>(src), 0);
+}
+
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
+ __builtin_msa_st_h(value, dst, 0);
+}
+
+inline void StoreInt16x8(Int16x8* dst, Int16x8 value) {
+ __builtin_msa_st_h(value, dst, 0);
+}
+
+inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
+ return __builtin_msa_ld_b(const_cast<std::uint8_t*>(src), 0);
+}
+
+inline Uint8x16 LoadUint8x16(const Uint8x16* src) {
+ return __builtin_msa_ld_b(const_cast<Uint8x16*>(src), 0);
+}
+
+inline void StoreUint8x16(std::uint8_t* dst, Uint8x16 value) {
+ __builtin_msa_st_b(value, dst, 0);
+}
+
+inline void StoreUint8x16(Uint8x16* dst, Uint8x16 value) {
+ __builtin_msa_st_b(value, dst, 0);
+}
+
+template <int Lane>
+std::int32_t GetLane(Int32x4 value) {
+ return __builtin_msa_copy_s_w(value, Lane);
+}
+
+template <int Lane>
+Int32x4 DupLane(Int32x4 value) {
+ static_assert(Lane >= 0 && Lane <= 3, "");
+ return __builtin_msa_splati_w(value, Lane);
+}
+
+inline Int32x4 Mul(Int32x4 a, std::int32_t b) {
+ return __builtin_msa_mulv_w(a, __builtin_msa_fill_w(b));
+}
+
+inline Int32x4 Min(Int32x4 a, Int32x4 b) { return __builtin_msa_min_s_w(a, b); }
+
+inline Int32x4 Max(Int32x4 a, Int32x4 b) { return __builtin_msa_max_s_w(a, b); }
+
+inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
+ return __builtin_msa_mulr_q_w(a, __builtin_msa_fill_w(b));
+}
+
+template <int Lane>
+Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
+ static_assert(Lane >= 0 && Lane <= 3, "");
+ return __builtin_msa_mulv_w(a, __builtin_msa_splati_w(b, Lane));
+}
+
+static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
+ // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c).
+#if 0
+ return __builtin_msa_maddv_w(a, b, c);
+#else
+ asm volatile("maddv.w %w[a], %w[b], %w[c]\n"
+ // Outputs
+ : [a] "+f"(a)
+ // Inputs
+ : [b] "f"(b), [c] "f"(c));
+ return a;
+#endif
+}
+
+inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
+ Int32x4 tmp = LoadInt32x4(acc);
+ tmp = workaround_msa_maddv_w(tmp, lhs, rhs);
+ StoreInt32x4(acc, tmp);
+}
+
+inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
+ Int32x4 tmp = LoadInt32x4(acc);
+ tmp = workaround_msa_maddv_w(tmp, lhs, __builtin_msa_fill_w(rhs));
+ StoreInt32x4(acc, tmp);
+}
+
+template <int Lane>
+inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
+ static_assert(Lane >= 0 && Lane <= 3, "");
+ Int32x4 tmp = LoadInt32x4(acc);
+ tmp = workaround_msa_maddv_w(tmp, lhs, __builtin_msa_splati_w(rhs, Lane));
+ StoreInt32x4(acc, tmp);
+}
+
+template <>
+struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
+ static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
+ RegBlockUint8<8, 8> result;
+ for (int i = 0; i < 4; i++) {
+ result.buf.reg[i] = LoadUint8x16(src + 16 * i);
+ }
+ return result;
+ }
+};
+
+template <>
+struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
+ static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
+ RegBlockInt32<8, 8> result;
+ for (int i = 0; i < 16; i++) {
+ result.buf.reg[i] = LoadInt32x4(src + 4 * i);
+ }
+ return result;
+ }
+};
+
+template <>
+struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
+ static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
+ RegBlockInt16<8, 8> result;
+ for (int i = 0; i < 8; i++) {
+ result.buf.reg[i] = LoadInt16x8(src + 8 * i);
+ }
+ return result;
+ }
+};
+
+} // end namespace gemmlowp
+
+#include "simd_wrappers_common_neon_sse.h"
+
+#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_
diff --git a/internal/simd_wrappers_neon.h b/internal/simd_wrappers_neon.h
index c992b15..2949173 100644
--- a/internal/simd_wrappers_neon.h
+++ b/internal/simd_wrappers_neon.h
@@ -22,6 +22,8 @@
namespace gemmlowp {
using Int32x4 = int32x4_t;
+using Int16x4 = int16x4_t;
+using Int16x8 = int16x8_t;
using Uint8x8 = uint8x8_t;
template <int ScalarCount>
@@ -31,6 +33,14 @@ struct RegisterType<std::int32_t, ScalarCount> {
};
template <int ScalarCount>
+struct RegisterType<std::int16_t, ScalarCount> {
+ using Type = typename std::conditional<
+ ScalarCount >= 8, Int16x8,
+ typename std::conditional<ScalarCount >= 4, Int16x4,
+ std::int16_t>::type>::type;
+};
+
+template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
using Type = typename std::conditional<
ScalarCount >= 8, Uint8x8,
@@ -39,11 +49,21 @@ struct RegisterType<std::uint8_t, ScalarCount> {
};
inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
+inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }
+inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }
inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
vst1q_s32(dst, value);
}
+inline void StoreInt16x4(std::int16_t* dst, Int16x4 value) {
+ vst1_s16(dst, value);
+}
+
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
+ vst1q_s16(dst, value);
+}
+
template <int Lane>
std::int32_t GetLane(Int32x4 value) {
return vgetq_lane_s32(value, Lane);
@@ -122,6 +142,17 @@ inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
}
template <>
+struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
+ static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
+ RegBlockInt16<8, 8> result;
+ for (int i = 0; i < 8; i++) {
+ result.buf.reg[i] = vld1q_s16(src + 8 * i);
+ }
+ return result;
+ }
+};
+
+template <>
struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
RegBlockUint8<8, 8> result;
diff --git a/internal/simd_wrappers_sse.h b/internal/simd_wrappers_sse.h
index 6480b66..3b78cb4 100644
--- a/internal/simd_wrappers_sse.h
+++ b/internal/simd_wrappers_sse.h
@@ -22,6 +22,7 @@
namespace gemmlowp {
using Int32x4 = __m128i;
+using Int16x8 = __m128i;
using Uint8x16 = __m128i;
template <int ScalarCount>
@@ -31,6 +32,12 @@ struct RegisterType<std::int32_t, ScalarCount> {
};
template <int ScalarCount>
+struct RegisterType<std::int16_t, ScalarCount> {
+ using Type =
+ typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
+};
+
+template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
using Type = typename std::conditional<
ScalarCount >= 16, Uint8x16,
@@ -42,10 +49,18 @@ inline Int32x4 LoadInt32x4(const std::int32_t* src) {
return _mm_loadu_si128(reinterpret_cast<const Int32x4*>(src));
}
+inline Int32x4 LoadInt16x8(const std::int16_t* src) {
+ return _mm_loadu_si128(reinterpret_cast<const Int16x8*>(src));
+}
+
inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value);
}
+inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
+ _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), value);
+}
+
inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
return _mm_loadu_si128(reinterpret_cast<const Uint8x16*>(src));
}
@@ -116,6 +131,17 @@ struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
}
};
+template <>
+struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
+ static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
+ RegBlockInt16<8, 8> result;
+ for (int i = 0; i < 8; i++) {
+ result.buf.reg[i] = LoadInt16x8(src + 8 * i);
+ }
+ return result;
+ }
+};
+
} // end namespace gemmlowp
#include "simd_wrappers_common_neon_sse.h"
diff --git a/internal/single_thread_gemm.h b/internal/single_thread_gemm.h
index 3d430c5..35a7835 100644
--- a/internal/single_thread_gemm.h
+++ b/internal/single_thread_gemm.h
@@ -89,10 +89,9 @@ void SingleThreadGemm(SingleThreadGemmContext* context,
Allocator* allocator = context->allocator();
BlockParams block_params;
- block_params.Init<KernelFormat>(rows, cols, depth, 1,
- context->l1_bytes_to_use(),
- context->l2_bytes_to_use(),
- context->l2_rhs_factor());
+ block_params.Init<KernelFormat>(
+ rows, cols, depth, 1, context->l1_bytes_to_use(),
+ context->l2_bytes_to_use(), context->l2_rhs_factor());
#ifdef GEMMLOWP_PROFILING_SIZES
// Using a static map of label strings. Not reentrant at all!
diff --git a/meta/multi_thread_common.h b/meta/multi_thread_common.h
index dc1b799..0b35759 100644
--- a/meta/multi_thread_common.h
+++ b/meta/multi_thread_common.h
@@ -20,6 +20,15 @@
namespace gemmlowp {
namespace meta {
+inline int ResolveMaxThreads(int max_threads) {
+ if (max_threads == 0) {
+ static const int hardware_threads_count =
+ static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));
+ return hardware_threads_count;
+ }
+ return max_threads;
+}
+
template <typename WorkersPool>
class SimpleContext {
public:
diff --git a/profiling/instrumentation.h b/profiling/instrumentation.h
index 539076a..437fe54 100644
--- a/profiling/instrumentation.h
+++ b/profiling/instrumentation.h
@@ -24,7 +24,6 @@
#ifndef GEMMLOWP_PROFILING_INSTRUMENTATION_H_
#define GEMMLOWP_PROFILING_INSTRUMENTATION_H_
-#include <pthread.h>
#include <cstdio>
#ifndef GEMMLOWP_USE_STLPORT
@@ -32,15 +31,15 @@
#else
#include <stdint.h>
namespace std {
-using ::uint8_t;
-using ::uint16_t;
-using ::uint32_t;
-using ::int8_t;
using ::int16_t;
using ::int32_t;
+using ::int8_t;
using ::size_t;
+using ::uint16_t;
+using ::uint32_t;
+using ::uint8_t;
using ::uintptr_t;
-}
+} // namespace std
#endif
#include <algorithm>
@@ -52,6 +51,8 @@ using ::uintptr_t;
#include <set>
#endif
+#include "./pthread_everywhere.h"
+
namespace gemmlowp {
inline void ReleaseBuildAssertion(bool condition, const char* msg) {
diff --git a/profiling/pthread_everywhere.h b/profiling/pthread_everywhere.h
index 7e12d66..df17c6f 100644
--- a/profiling/pthread_everywhere.h
+++ b/profiling/pthread_everywhere.h
@@ -18,8 +18,6 @@
#ifndef GEMMLOWP_PROFILING_PTHREAD_EVERYWHERE_H_
#define GEMMLOWP_PROFILING_PTHREAD_EVERYWHERE_H_
-#include "pthread_everywhere.h"
-
#ifndef _WIN32
#define GEMMLOWP_USE_PTHREAD
#endif
@@ -39,39 +37,29 @@
// structs; ours take nullptr_t. That is because gemmlowp always passes
// nullptr at the moment, so any support we would code for non-null
// attribs would be unused.
-#include <thread>
-#include <mutex>
#include <condition_variable>
#include <cstddef>
+#include <mutex>
+#include <thread>
namespace gemmlowp {
-using pthread_t = std::thread*;
-using pthread_mutex_t = std::mutex*;
-using pthread_cond_t = std::condition_variable*;
-inline void pthread_create(pthread_t* thread, std::nullptr_t,
- void *(*start_routine) (void *), void *arg) {
+using pthread_t = std::thread *;
+using pthread_mutex_t = std::mutex *;
+using pthread_cond_t = std::condition_variable *;
+inline void pthread_create(pthread_t *thread, std::nullptr_t,
+ void *(*start_routine)(void *), void *arg) {
*thread = new std::thread(start_routine, arg);
}
-inline void pthread_join(pthread_t thread, std::nullptr_t) {
- thread->join();
-}
+inline void pthread_join(pthread_t thread, std::nullptr_t) { thread->join(); }
inline void pthread_mutex_init(pthread_mutex_t *mutex, std::nullptr_t) {
*mutex = new std::mutex;
}
-inline void pthread_mutex_lock(pthread_mutex_t* mutex) {
- (*mutex)->lock();
-}
-inline void pthread_mutex_unlock(pthread_mutex_t* mutex) {
- (*mutex)->unlock();
-}
-inline void pthread_mutex_destroy(pthread_mutex_t *mutex) {
- delete *mutex;
-}
+inline void pthread_mutex_lock(pthread_mutex_t *mutex) { (*mutex)->lock(); }
+inline void pthread_mutex_unlock(pthread_mutex_t *mutex) { (*mutex)->unlock(); }
+inline void pthread_mutex_destroy(pthread_mutex_t *mutex) { delete *mutex; }
inline void pthread_cond_init(pthread_cond_t *cond, std::nullptr_t) {
*cond = new std::condition_variable;
}
-inline void pthread_cond_signal(pthread_cond_t* cond) {
- (*cond)->notify_one();
-}
+inline void pthread_cond_signal(pthread_cond_t *cond) { (*cond)->notify_one(); }
inline void pthread_cond_wait(pthread_cond_t *cond, pthread_mutex_t *mutex) {
std::unique_lock<std::mutex> lock(**mutex, std::adopt_lock);
(*cond)->wait(lock);
@@ -79,10 +67,8 @@ inline void pthread_cond_wait(pthread_cond_t *cond, pthread_mutex_t *mutex) {
// the lock is not released
lock.release();
}
-inline void pthread_cond_destroy(pthread_cond_t *cond) {
- delete *cond;
-}
+inline void pthread_cond_destroy(pthread_cond_t *cond) { delete *cond; }
} // end namespace gemmlowp
#endif
-#endif // GEMMLOWP_PROFILING_PTHREAD_EVERYWHERE_H_ \ No newline at end of file
+#endif // GEMMLOWP_PROFILING_PTHREAD_EVERYWHERE_H_
diff --git a/public/output_stages.h b/public/output_stages.h
index 23bcdc0..1d5fca4 100644
--- a/public/output_stages.h
+++ b/public/output_stages.h
@@ -66,8 +66,9 @@ struct OutputStageQuantizeDownInt32ToUint8ScalePC {
};
// This output stage takes int32 values and returns still int32 values,
-// but "quantized down" to the uint8 scale; in other words, its output
-// is typically what one would then clamp to [0..255] and cast to uint8
+// but "quantized down" to a difference scale; for example, in a pipeline
+// that outputs uint8 values in [0..255], the output of this stage would be
+// int32 values ready to be clamped to [0..255] and casted to uint8
// (see OutputStageSaturatingCastToUint8).
//
// This "quantization down" process depends on 3 parameters,
@@ -111,17 +112,42 @@ struct OutputStageQuantizeDownInt32ToUint8ScalePC {
// expansions that implicitly rely on 0-padding. If 0 were not
// a representable value, such operations would have to pad
// using a nonzero value, introducing bias in the computation.
-struct OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint {
+struct OutputStageQuantizeDownInt32ByFixedPoint {
std::int32_t result_fixedpoint_multiplier;
std::int32_t result_shift;
std::int32_t result_offset_after_shift;
};
+// OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint is the old deprecated
+// name of OutputStageQuantizeDownInt32ByFixedPoint, before we noticed that
+// there really wasn't anything Uint8-specific about it.
+using OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint = OutputStageQuantizeDownInt32ByFixedPoint;
+
+// Variant of OutputStageQuantizeDownInt32ByFixedPoint where the 'shift'
+// is not necessarily just a right shift, so we can represent multipliers
+// greater than 1. This takes an result_exponent parameter; when it's
+// <= 0, this is equivalent to OutputStageQuantizeDownInt32ByFixedPoint
+// with result_shift = -result_exponent.
+// In the general case, this consists in first left-shifting by
+// std::max(result_exponent, 0), before doing the same as
+// OutputStageQuantizeDownInt32ByFixedPoint with
+// result_shift = std::max(-result_exponent, 0).
+struct OutputStageScaleInt32ByFixedPointAndExponent {
+ std::int32_t result_fixedpoint_multiplier;
+ std::int32_t result_exponent;
+ std::int32_t result_offset_after_shift;
+};
+
// This output stage takes int32 values that are expected to be already
// on the final uint8 scale, but not necessarily in the [0..255] range.
// It clamps them to the [0..255] range and returns them casted to uint8.
struct OutputStageSaturatingCastToUint8 {};
+// This output stage takes int32 values that are expected to be already
+// on the final int16 scale, but not necessarily in the [-32768..32767] range.
+// It clamps them to the [-32768..32767] range and returns them casted to int16.
+struct OutputStageSaturatingCastToInt16 {};
+
// This output stage depends on a "bias vector" that should contain int32
// entries, and be either a row-vector of the same number of columns as the
// result matrix, or a column-vector of the same number of rows as the
diff --git a/scripts/ci-test.sh b/scripts/ci-test.sh
index de6e344..83cc5cd 100755
--- a/scripts/ci-test.sh
+++ b/scripts/ci-test.sh
@@ -11,4 +11,4 @@ if [ $TEST == "arm" ]; then
fi
if [ $TEST == "x86" ]; then
make -f Makefile.travis unittest
-fi
+fi
diff --git a/standalone/neon-gemm-kernel-benchmark.cc b/standalone/neon-gemm-kernel-benchmark.cc
index 2a936c1..bff33fb 100644
--- a/standalone/neon-gemm-kernel-benchmark.cc
+++ b/standalone/neon-gemm-kernel-benchmark.cc
@@ -61,15 +61,30 @@
#include <cassert>
#include <cstdint>
#include <cstdlib>
+#include <cstring>
#include <iostream>
#include <random>
#include <type_traits>
-#if !defined __arm__ && !defined __aarch64__
-#error This benchmark assumes ARM (for inline assembly sections).
+#if !defined(__arm__) && !defined(__aarch64__) && \
+ !(defined(__mips) && (__mips_isa_rev >= 5) && defined(__mips_msa))
+#error This benchmark assumes ARM or MIPS (for intrinsics and inline assembly sections).
#endif
+#if defined(__arm__) || defined(__aarch64__)
#include <arm_neon.h>
+#endif
+
+#if defined(__mips)
+#include <msa.h>
+
+// Some convenience macros to hide differences between MIPS32 and MIPS64.
+#ifdef __LP64__
+#define GEMMLOWP_MIPS_XADDIU "daddiu"
+#else
+#define GEMMLOWP_MIPS_XADDIU "addiu"
+#endif
+#endif
// Typically one wants to fit in L1 cache, and GEMM implementations
// are carefully optimized to tune their access patterns to that effect.
@@ -2501,6 +2516,291 @@ struct NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits {
}
};
+#ifdef __ARM_FEATURE_DOTPROD
+// Kernels utilizing the Armv8.2 Dot Product extension.
+//
+// The dot product instructions work by taking 4 consecutive 8-bit depth
+// values from each operand, multiplying the 4 pairs together and
+// accumulating all the results into the corresponding 32-bit accumulator
+// lane. As such, the operation is identical to a 32-bit instruction (like
+// FMLA used in SGEMM), except that 4 depth values are processed at a time
+// instead of 1.
+
+// Thus, this first kernel is a carbon copy of
+// "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good
+// performance for most processors) below with the opcode (fmla -> udot) and
+// types (float32 -> uint8/uint32) changed.
+//
+// A signed version of this kernel could be produced by replacing "udot"
+// with "sdot" - performance should be identical to this udot kernel.
+struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "mov x0, %[accum_ptr]\n"
+ "ld1 {v8.4s}, [x0], #16\n"
+ "ld1 {v16.4s}, [x0], #16\n"
+ "ld1 {v24.4s}, [x0], #16\n"
+ "ld1 {v9.4s}, [x0], #16\n"
+ "ld1 {v17.4s}, [x0], #16\n"
+ "ld1 {v25.4s}, [x0], #16\n"
+ "ld1 {v10.4s}, [x0], #16\n"
+ "ld1 {v18.4s}, [x0], #16\n"
+ "ld1 {v26.4s}, [x0], #16\n"
+ "ld1 {v11.4s}, [x0], #16\n"
+ "ld1 {v19.4s}, [x0], #16\n"
+ "ld1 {v27.4s}, [x0], #16\n"
+ "ld1 {v12.4s}, [x0], #16\n"
+ "ld1 {v20.4s}, [x0], #16\n"
+ "ld1 {v28.4s}, [x0], #16\n"
+ "ld1 {v13.4s}, [x0], #16\n"
+ "ld1 {v21.4s}, [x0], #16\n"
+ "ld1 {v29.4s}, [x0], #16\n"
+ "ld1 {v14.4s}, [x0], #16\n"
+ "ld1 {v22.4s}, [x0], #16\n"
+ "ld1 {v30.4s}, [x0], #16\n"
+ "ld1 {v15.4s}, [x0], #16\n"
+ "ld1 {v23.4s}, [x0], #16\n"
+ "ld1 {v31.4s}, [x0], #16\n"
+
+ // The start of the loop assumes first Rhs cell is already loaded, so
+ // do it here for first iteration.
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n"
+
+ // And the same for the first Lhs cell.
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+
+ // Start the MACs at the head of the loop - 1st cell from each side
+ // already loaded.
+ "udot v8.4s, v2.16b, v0.b[0]\n"
+ "udot v9.4s, v2.16b, v0.b[1]\n"
+ "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell.
+ "udot v10.4s, v2.16b, v0.b[2]\n"
+ "udot v11.4s, v2.16b, v0.b[3]\n"
+ "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell.
+ "udot v12.4s, v2.16b, v1.b[0]\n"
+ "udot v13.4s, v2.16b, v1.b[1]\n"
+ "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell.
+ "udot v14.4s, v2.16b, v1.b[2]\n"
+ "udot v15.4s, v2.16b, v1.b[3]\n"
+ "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load
+ // for the next iteration early.
+ "udot v16.4s, v3.16b, v0.b[0]\n"
+ "udot v17.4s, v3.16b, v0.b[1]\n"
+ "udot v18.4s, v3.16b, v0.b[2]\n"
+ "udot v19.4s, v3.16b, v0.b[3]\n"
+ "udot v20.4s, v3.16b, v1.b[0]\n"
+ "udot v21.4s, v3.16b, v1.b[1]\n"
+ "udot v22.4s, v3.16b, v1.b[2]\n"
+ "udot v23.4s, v3.16b, v1.b[3]\n"
+ "udot v24.4s, v4.16b, v0.b[0]\n"
+ "udot v25.4s, v4.16b, v0.b[1]\n"
+ "udot v26.4s, v4.16b, v0.b[2]\n"
+ "udot v27.4s, v4.16b, v0.b[3]\n"
+ "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell -
+ // load for the next iteration early.
+ "udot v28.4s, v4.16b, v1.b[0]\n"
+ "udot v29.4s, v4.16b, v1.b[1]\n"
+
+ // Loop. Decrement loop index (depth) by 4 as udot processes 4
+ // depth values.
+ "subs %w[depth], %w[depth], #4\n"
+ "udot v30.4s, v4.16b, v1.b[2]\n"
+ "udot v31.4s, v4.16b, v1.b[3]\n"
+
+ "bne " GEMMLOWP_LABEL_LOOP
+ "b\n"
+
+ // Store accumulators
+ "mov x0, %[accum_ptr]\n"
+ "st1 {v8.16b}, [x0], #16\n"
+ "st1 {v16.16b}, [x0], #16\n"
+ "st1 {v24.16b}, [x0], #16\n"
+ "st1 {v9.16b}, [x0], #16\n"
+ "st1 {v17.16b}, [x0], #16\n"
+ "st1 {v25.16b}, [x0], #16\n"
+ "st1 {v10.16b}, [x0], #16\n"
+ "st1 {v18.16b}, [x0], #16\n"
+ "st1 {v26.16b}, [x0], #16\n"
+ "st1 {v11.16b}, [x0], #16\n"
+ "st1 {v19.16b}, [x0], #16\n"
+ "st1 {v27.16b}, [x0], #16\n"
+ "st1 {v12.16b}, [x0], #16\n"
+ "st1 {v20.16b}, [x0], #16\n"
+ "st1 {v28.16b}, [x0], #16\n"
+ "st1 {v13.16b}, [x0], #16\n"
+ "st1 {v21.16b}, [x0], #16\n"
+ "st1 {v29.16b}, [x0], #16\n"
+ "st1 {v14.16b}, [x0], #16\n"
+ "st1 {v22.16b}, [x0], #16\n"
+ "st1 {v30.16b}, [x0], #16\n"
+ "st1 {v15.16b}, [x0], #16\n"
+ "st1 {v23.16b}, [x0], #16\n"
+ "st1 {v31.16b}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7",
+ "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17",
+ "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
+ "v28", "v29", "v30", "v31");
+ }
+};
+
+// As above, except tuned for Cortex-A55r1.
+//
+// Similarly, this is a clone of NEON_64bit_GEMM_Float32_WithScalar_A55r1
+// with the names changed.
+struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "mov x0, %[accum_ptr]\n"
+ "ld1 {v8.4s}, [x0], #16\n"
+ "ld1 {v16.4s}, [x0], #16\n"
+ "ld1 {v24.4s}, [x0], #16\n"
+ "ld1 {v9.4s}, [x0], #16\n"
+ "ld1 {v17.4s}, [x0], #16\n"
+ "ld1 {v25.4s}, [x0], #16\n"
+ "ld1 {v10.4s}, [x0], #16\n"
+ "ld1 {v18.4s}, [x0], #16\n"
+ "ld1 {v26.4s}, [x0], #16\n"
+ "ld1 {v11.4s}, [x0], #16\n"
+ "ld1 {v19.4s}, [x0], #16\n"
+ "ld1 {v27.4s}, [x0], #16\n"
+ "ld1 {v12.4s}, [x0], #16\n"
+ "ld1 {v20.4s}, [x0], #16\n"
+ "ld1 {v28.4s}, [x0], #16\n"
+ "ld1 {v13.4s}, [x0], #16\n"
+ "ld1 {v21.4s}, [x0], #16\n"
+ "ld1 {v29.4s}, [x0], #16\n"
+ "ld1 {v14.4s}, [x0], #16\n"
+ "ld1 {v22.4s}, [x0], #16\n"
+ "ld1 {v30.4s}, [x0], #16\n"
+ "ld1 {v15.4s}, [x0], #16\n"
+ "ld1 {v23.4s}, [x0], #16\n"
+ "ld1 {v31.4s}, [x0], #16\n"
+
+ // For details on how this kernel works, see the Float32 kernel below.
+
+ "ldr d0, [%[rhs_ptr]]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n"
+
+ "ldr q2, [%[lhs_ptr]]\n"
+ "ldr q3, [%[lhs_ptr], #16]\n"
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+
+ "udot v8.4s, v2.16b, v0.b[0]\n"
+ "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
+ "udot v9.4s, v2.16b, v0.b[1]\n"
+ "ins v0.d[1], x18\n" // Finish loading v0
+ "udot v16.4s, v3.16b, v0.b[0]\n" // out of sequence - used to reduce load/use pressure.
+ "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
+ "udot v17.4s, v3.16b, v0.b[1]\n" // out of sequence - used to reduce load/use pressure.
+ "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer.
+ "udot v10.4s, v2.16b, v0.b[2]\n"
+ "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
+ "udot v11.4s, v2.16b, v0.b[3]\n"
+ "ins v1.d[1], x18\n" // Finish loading v1
+ "udot v12.4s, v2.16b, v1.b[0]\n"
+ "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
+ "udot v13.4s, v2.16b, v1.b[1]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer.
+ "udot v14.4s, v2.16b, v1.b[2]\n"
+
+ "udot v15.4s, v2.16b, v1.b[3]\n"
+ "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
+ "udot v18.4s, v3.16b, v0.b[2]\n"
+ "ins v4.d[1], x18\n" // Finish loading v4
+ "udot v19.4s, v3.16b, v0.b[3]\n"
+ "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
+ "udot v20.4s, v3.16b, v1.b[0]\n"
+ "subs %w[depth], %w[depth], #4\n"
+ "udot v21.4s, v3.16b, v1.b[1]\n"
+
+ "udot v22.4s, v3.16b, v1.b[2]\n"
+
+ "udot v23.4s, v3.16b, v1.b[3]\n"
+ "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
+ "udot v24.4s, v4.16b, v0.b[0]\n"
+ "ins v2.d[1], x18\n" // Finish loading next v2
+ "udot v25.4s, v4.16b, v0.b[1]\n"
+ "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
+ "udot v26.4s, v4.16b, v0.b[2]\n"
+
+ "udot v27.4s, v4.16b, v0.b[3]\n"
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
+ "udot v28.4s, v4.16b, v1.b[0]\n"
+ "ins v3.d[1], x18\n" // Finish loading next v3
+ "udot v29.4s, v4.16b, v1.b[1]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
+ "udot v30.4s, v4.16b, v1.b[2]\n"
+
+ "udot v31.4s, v4.16b, v1.b[3]\n"
+ "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators
+ "mov x0, %[accum_ptr]\n"
+ "st1 {v8.4s}, [x0], #16\n"
+ "st1 {v16.4s}, [x0], #16\n"
+ "st1 {v24.4s}, [x0], #16\n"
+ "st1 {v9.4s}, [x0], #16\n"
+ "st1 {v17.4s}, [x0], #16\n"
+ "st1 {v25.4s}, [x0], #16\n"
+ "st1 {v10.4s}, [x0], #16\n"
+ "st1 {v18.4s}, [x0], #16\n"
+ "st1 {v26.4s}, [x0], #16\n"
+ "st1 {v11.4s}, [x0], #16\n"
+ "st1 {v19.4s}, [x0], #16\n"
+ "st1 {v27.4s}, [x0], #16\n"
+ "st1 {v12.4s}, [x0], #16\n"
+ "st1 {v20.4s}, [x0], #16\n"
+ "st1 {v28.4s}, [x0], #16\n"
+ "st1 {v13.4s}, [x0], #16\n"
+ "st1 {v21.4s}, [x0], #16\n"
+ "st1 {v29.4s}, [x0], #16\n"
+ "st1 {v14.4s}, [x0], #16\n"
+ "st1 {v22.4s}, [x0], #16\n"
+ "st1 {v30.4s}, [x0], #16\n"
+ "st1 {v15.4s}, [x0], #16\n"
+ "st1 {v23.4s}, [x0], #16\n"
+ "st1 {v31.4s}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+ }
+};
+#endif // __ARM_FEATURE_DOTPROD
+
// We don't actually use int32*int32 in production. This is just an
// experiment to help dissociate the effect of integer-vs-float, from the
// effect of operands width.
@@ -3203,8 +3503,172 @@ struct NEON_64bit_GEMM_Float32_WithScalar_A53 {
};
#endif
+// Faster kernel contributed by ARM. Tuned for A55r1.
+struct NEON_64bit_GEMM_Float32_WithScalar_A55r1 {
+ typedef float OperandType;
+ typedef float AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "mov x0, %[accum_ptr]\n"
+ "ld1 {v8.4s}, [x0], #16\n"
+ "ld1 {v16.4s}, [x0], #16\n"
+ "ld1 {v24.4s}, [x0], #16\n"
+ "ld1 {v9.4s}, [x0], #16\n"
+ "ld1 {v17.4s}, [x0], #16\n"
+ "ld1 {v25.4s}, [x0], #16\n"
+ "ld1 {v10.4s}, [x0], #16\n"
+ "ld1 {v18.4s}, [x0], #16\n"
+ "ld1 {v26.4s}, [x0], #16\n"
+ "ld1 {v11.4s}, [x0], #16\n"
+ "ld1 {v19.4s}, [x0], #16\n"
+ "ld1 {v27.4s}, [x0], #16\n"
+ "ld1 {v12.4s}, [x0], #16\n"
+ "ld1 {v20.4s}, [x0], #16\n"
+ "ld1 {v28.4s}, [x0], #16\n"
+ "ld1 {v13.4s}, [x0], #16\n"
+ "ld1 {v21.4s}, [x0], #16\n"
+ "ld1 {v29.4s}, [x0], #16\n"
+ "ld1 {v14.4s}, [x0], #16\n"
+ "ld1 {v22.4s}, [x0], #16\n"
+ "ld1 {v30.4s}, [x0], #16\n"
+ "ld1 {v15.4s}, [x0], #16\n"
+ "ld1 {v23.4s}, [x0], #16\n"
+ "ld1 {v31.4s}, [x0], #16\n"
+
+ // A55r1 requires a hybrid of the A53 and standard approaches.
+ //
+ // Like A53, this processor prefers 64-bit loads.
+ //
+ // Unlike A53, it is capable of dual-issuing a 64-bit vector load
+ // (or INS) with a FMLA instruction.
+ //
+ // Therefore we aim to issue an FMLA instruction every cycle.
+ // Alongside three FMLAs we can dual issue a (vector) 64-bit load, a
+ // scalar 64-bit load and finally an INS to replicate the effect of
+ // a single 128-bit load.
+ //
+ // The loop contains 24 FMLA instructions, and 5 vector registers
+ // need to be loaded, consuming 15 dual issue slots. This leaves 9
+ // dual issue slots. Four of these are used for loop housekeeping
+ // (2 pointer adds, 1 counter update and 1 branch), leaving 5 left
+ // over (marked by blank lines).
+ //
+ // Choice of x18 to store the upper halves on their way into the
+ // vector registers is arbitrary. Added to the clobber list so that
+ // the compiler will make it available.
+
+
+ // At the start of the loop, it is assumed that v0 is "half loaded" -
+ // bottom half in place in d0 and the upper half in x18 ready to
+ // insert. So set that up here for the first iteration:
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of first Rhs cell
+ "ldr x18, [%[rhs_ptr], #8]\n" // Upper half
+
+ // v2-v3 should be fully loaded - as it's outside the loop proper it's fine
+ // to use a 128-bit load here.
+ "ldr q2, [%[lhs_ptr]]\n" // first Lhs cell
+ "ldr q3, [%[lhs_ptr], #16]\n" // second Lhs cell
+
+ GEMMLOWP_LABEL_LOOP
+ ":\n"
+
+ "fmla v8.4s, v2.4s, v0.s[0]\n"
+ "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1
+ "fmla v9.4s, v2.4s, v0.s[1]\n"
+ "ins v0.d[1], x18\n" // Finish loading v0
+ "fmla v16.4s, v3.4s, v0.s[0]\n" // out of sequence - used to reduce load/use pressure.
+ "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register
+ "fmla v17.4s, v3.4s, v0.s[1]\n" // out of sequence - used to reduce load/use pressure.
+ "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer.
+ "fmla v10.4s, v2.4s, v0.s[2]\n"
+ "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4
+ "fmla v11.4s, v2.4s, v0.s[3]\n"
+ "ins v1.d[1], x18\n" // Finish loading v1
+ "fmla v12.4s, v2.4s, v1.s[0]\n"
+ "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register
+ "fmla v13.4s, v2.4s, v1.s[1]\n"
+ "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer.
+ "fmla v14.4s, v2.4s, v1.s[2]\n"
+
+ "fmla v15.4s, v2.4s, v1.s[3]\n"
+ "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time)
+ "fmla v18.4s, v3.4s, v0.s[2]\n"
+ "ins v4.d[1], x18\n" // Finish loading v4
+ "fmla v19.4s, v3.4s, v0.s[3]\n"
+ "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register
+ "fmla v20.4s, v3.4s, v1.s[0]\n"
+ "subs %w[depth], %w[depth], #1\n"
+ "fmla v21.4s, v3.4s, v1.s[1]\n"
+
+ "fmla v22.4s, v3.4s, v1.s[2]\n"
+
+ "fmla v23.4s, v3.4s, v1.s[3]\n"
+ "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time)
+ "fmla v24.4s, v4.4s, v0.s[0]\n"
+ "ins v2.d[1], x18\n" // Finish loading next v2
+ "fmla v25.4s, v4.4s, v0.s[1]\n"
+ "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register
+ "fmla v26.4s, v4.4s, v0.s[2]\n"
+
+ "fmla v27.4s, v4.4s, v0.s[3]\n"
+ "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time)
+ "fmla v28.4s, v4.4s, v1.s[0]\n"
+ "ins v3.d[1], x18\n" // Finish loading next v3
+ "fmla v29.4s, v4.4s, v1.s[1]\n"
+ "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register
+ "fmla v30.4s, v4.4s, v1.s[2]\n"
+
+ "fmla v31.4s, v4.4s, v1.s[3]\n"
+ "bne " GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators
+ "mov x0, %[accum_ptr]\n"
+ "st1 {v8.4s}, [x0], #16\n"
+ "st1 {v16.4s}, [x0], #16\n"
+ "st1 {v24.4s}, [x0], #16\n"
+ "st1 {v9.4s}, [x0], #16\n"
+ "st1 {v17.4s}, [x0], #16\n"
+ "st1 {v25.4s}, [x0], #16\n"
+ "st1 {v10.4s}, [x0], #16\n"
+ "st1 {v18.4s}, [x0], #16\n"
+ "st1 {v26.4s}, [x0], #16\n"
+ "st1 {v11.4s}, [x0], #16\n"
+ "st1 {v19.4s}, [x0], #16\n"
+ "st1 {v27.4s}, [x0], #16\n"
+ "st1 {v12.4s}, [x0], #16\n"
+ "st1 {v20.4s}, [x0], #16\n"
+ "st1 {v28.4s}, [x0], #16\n"
+ "st1 {v13.4s}, [x0], #16\n"
+ "st1 {v21.4s}, [x0], #16\n"
+ "st1 {v29.4s}, [x0], #16\n"
+ "st1 {v14.4s}, [x0], #16\n"
+ "st1 {v22.4s}, [x0], #16\n"
+ "st1 {v30.4s}, [x0], #16\n"
+ "st1 {v15.4s}, [x0], #16\n"
+ "st1 {v23.4s}, [x0], #16\n"
+ "st1 {v31.4s}, [x0], #16\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31");
+ }
+};
+
#endif // __aarch64__
+#if defined(__arm__) || defined(__aarch64__)
#ifndef __aarch64__
inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
const int32x2_t c = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
@@ -3388,6 +3852,974 @@ using NEON_32bit_GEMM_Float32_WithScalar_intrinsics =
using NEON_64bit_GEMM_Float32_WithScalar_intrinsics =
NEON_GEMM_Float32_WithScalar_intrinsics<2>;
+#endif // __arm__ || __aarch64__
+
+#ifdef __mips
+static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
+ // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c).
+#if 0
+ return __builtin_msa_maddv_w(a, b, c);
+#else
+ asm volatile("maddv.w %w[a], %w[b], %w[c]\n"
+ // Outputs
+ : [a] "+f"(a)
+ // Inputs
+ : [b] "f"(b), [c] "f"(c));
+ return a;
+#endif
+}
+
+// Using 32x32=32 multiplications.
+// 20 MSA regs used:
+// - 12 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~55 instructions in the loop.
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics {
+ typedef std::uint8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ const v16i8 zeroes = __builtin_msa_ldi_b(0);
+ v4i32 acc[3][4];
+ // Load accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+
+ while (depth > 0) {
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ v8i16 lhs[6];
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
+ lhs[1] =
+ reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[0])));
+ lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+ lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+ lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+
+ // Depth 0.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ }
+ }
+
+ // Depth 1.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ }
+ }
+
+ lhs_ptr += 24;
+ rhs_ptr += 8;
+ depth -= 2;
+ }
+
+ // Store accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+ }
+};
+
+// Assembly implementation of the above
+// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics.
+// Using 32x32=32 multiplications.
+// 20 MSA regs used:
+// - 12 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~55 instructions in the loop.
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly {
+ typedef std::uint8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w19, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A half of the 2x4 cell of Rhs is stored in 32bit in w18.
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w12-w17.
+ // A 12x4 block of accumulators is stored in 32bit in w0-w11.
+ //
+ // +------+------+------+------+
+ // Rhs |w18[0]|w18[1]|w18[2]|w18[3]|
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+---+ - - - - +------+------+------+------+
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // |w12|w15| | w0 | w1 | w2 | w3 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // |w13|w16| | w4 | w5 | w6 | w7 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // |w14|w17| | w8 | w9 | w10 | w11 |
+ // +---+---+ - - - - +------+------+------+------+
+ //
+ // Accumulator
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w12, 0(%[lhs_ptr])\n"
+ "ld.b $w13, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w12, $w19, $w12\n"
+ "ilvl.b $w14, $w19, $w13\n"
+ "ilvr.b $w13, $w19, $w13\n"
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ "ilvl.h $w15, $w19, $w12\n"
+ "ilvl.h $w16, $w19, $w13\n"
+ "ilvl.h $w17, $w19, $w14\n"
+ "ilvr.h $w12, $w19, $w12\n"
+ "ilvr.h $w13, $w19, $w13\n"
+ "ilvr.h $w14, $w19, $w14\n"
+
+ // Depth 0.
+ "fill.w $w18, $a0\n"
+ "lbu $a0, 4(%[rhs_ptr])\n"
+ "maddv.w $w0, $w12, $w18\n"
+ "maddv.w $w4, $w13, $w18\n"
+ "maddv.w $w8, $w14, $w18\n"
+ "fill.w $w18, $a1\n"
+ "lbu $a1, 5(%[rhs_ptr])\n"
+ "maddv.w $w1, $w12, $w18\n"
+ "maddv.w $w5, $w13, $w18\n"
+ "maddv.w $w9, $w14, $w18\n"
+ "fill.w $w18, $a2\n"
+ "lbu $a2, 6(%[rhs_ptr])\n"
+ "maddv.w $w2, $w12, $w18\n"
+ "maddv.w $w6, $w13, $w18\n"
+ "maddv.w $w10, $w14, $w18\n"
+ "fill.w $w18, $a3\n"
+ "lbu $a3, 7(%[rhs_ptr])\n"
+ "maddv.w $w3, $w12, $w18\n"
+ "maddv.w $w7, $w13, $w18\n"
+ "maddv.w $w11, $w14, $w18\n"
+
+ // Depth 1.
+ "fill.w $w18, $a0\n"
+ "maddv.w $w0, $w15, $w18\n"
+ "maddv.w $w4, $w16, $w18\n"
+ "maddv.w $w8, $w17, $w18\n"
+ "fill.w $w18, $a1\n"
+ "maddv.w $w1, $w15, $w18\n"
+ "maddv.w $w5, $w16, $w18\n"
+ "maddv.w $w9, $w17, $w18\n"
+ "fill.w $w18, $a2\n"
+ "maddv.w $w2, $w15, $w18\n"
+ "maddv.w $w6, $w16, $w18\n"
+ "maddv.w $w10, $w17, $w18\n"
+ "fill.w $w18, $a3\n"
+ "maddv.w $w3, $w15, $w18\n"
+ "maddv.w $w7, $w16, $w18\n"
+ "maddv.w $w11, $w17, $w18\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "a0", "a1", "a2", "a3",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19");
+ }
+};
+
+// Assembly implementation of the above
+// MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO).
+// Using 16x16=32 multiplications.
+// 20 MSA regs used:
+// - 12 accumulators
+// - 3 lhs
+// - 4 rhs
+// - 1 temps/zeroes
+// ~45 instructions in the loop.
+struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2 {
+ typedef std::uint8_t OperandType;
+ typedef std::int32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w19, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A 2x4 cell of Rhs is stored in 16bit in w15-w18 (each register
+ // contains 4 replicas of a pair of elements).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w12-w14.
+ // A 12x4 block of accumulators is stored in 32bit in w0-w11.
+ //
+ // +-----+-----+-----+-----+
+ // Rhs | w15 | w16 | w17 | w18 |
+ // +-----+-----+-----+-----+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w12| | w0 | w1 | w2 | w3 |
+ // |w12| | w0 | w1 | w2 | w3 |
+ // |w12| | w0 | w1 | w2 | w3 |
+ // |w12| | w0 | w1 | w2 | w3 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w13| | w4 | w5 | w6 | w7 |
+ // |w13| | w4 | w5 | w6 | w7 |
+ // |w13| | w4 | w5 | w6 | w7 |
+ // |w13| | w4 | w5 | w6 | w7 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ // |w14| | w8 | w9 | w10 | w11 |
+ // |w14| | w8 | w9 | w10 | w11 |
+ // |w14| | w8 | w9 | w10 | w11 |
+ // |w14| | w8 | w9 | w10 | w11 |
+ // +---+ - - - - +-----+-----+-----+-----+
+ //
+ // Accumulators
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w12, 0(%[lhs_ptr])\n"
+ "ld.b $w13, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w12, $w19, $w12\n"
+ "ilvl.b $w14, $w19, $w13\n"
+ "ilvr.b $w13, $w19, $w13\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
+ "ilvl.d $w15, $w19, $w12\n"
+ "ilvl.d $w16, $w19, $w13\n"
+ "ilvl.d $w17, $w19, $w14\n"
+ "ilvr.h $w12, $w15, $w12\n"
+ "ilvr.h $w13, $w16, $w13\n"
+ "ilvr.h $w14, $w17, $w14\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w.
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w15, $a0\n"
+ "fill.w $w16, $a1\n"
+ "fill.w $w17, $a2\n"
+ "fill.w $w18, $a3\n"
+
+ // Depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w12, $w15\n"
+ "dpadd_u.w $w4, $w13, $w15\n"
+ "dpadd_u.w $w8, $w14, $w15\n"
+ "dpadd_u.w $w1, $w12, $w16\n"
+ "dpadd_u.w $w5, $w13, $w16\n"
+ "dpadd_u.w $w9, $w14, $w16\n"
+ "dpadd_u.w $w2, $w12, $w17\n"
+ "dpadd_u.w $w6, $w13, $w17\n"
+ "dpadd_u.w $w10, $w14, $w17\n"
+ "dpadd_u.w $w3, $w12, $w18\n"
+ "dpadd_u.w $w7, $w13, $w18\n"
+ "dpadd_u.w $w11, $w14, $w18\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "v0", "v1",
+ "a0", "a1", "a2", "a3",
+ "t8", "t9",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19");
+ }
+};
+
+// Using 32x32=32 multiplications.
+// 32 MSA regs used:
+// - 24 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~95 instructions in the loop.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ const v16i8 zeroes = __builtin_msa_ldi_b(0);
+ v4i32 acc[3][8];
+ // Load accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 8; j++) {
+ acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+
+ while (depth > 0) {
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ v8i16 lhs[6];
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0));
+ lhs[1] =
+ reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0));
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[0])));
+ lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+ lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes,
+ reinterpret_cast<v16i8>(lhs[1])));
+
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+ lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]);
+ lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]);
+ lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]);
+
+ // Depth 0.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ }
+ }
+ for (int j = 4; j < 8; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs);
+ }
+ }
+
+ // Depth 1.
+ for (int j = 0; j < 4; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ }
+ }
+ for (int j = 4; j < 8; j++) {
+ // Load 1 byte of rhs, making 4 32-bit replicas of it.
+ v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 8]));
+ // Multiply-add into accumulators.
+ for (int i = 0; i < 3; i++) {
+ acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs);
+ }
+ }
+
+ lhs_ptr += 24;
+ rhs_ptr += 16;
+ depth -= 2;
+ }
+
+ // Store accumulators.
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 8; j++) {
+ __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0);
+ }
+ }
+ }
+};
+
+// Assembly implementation of the above
+// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics.
+// Using 32x32=32 multiplications.
+// 32 MSA regs used:
+// - 24 accumulators
+// - 6 lhs
+// - 1 rhs
+// - 1 temps/zeroes
+// ~95 instructions in the loop.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ "ld.w $w12, (12*16)(%[accum_ptr])\n"
+ "ld.w $w16, (13*16)(%[accum_ptr])\n"
+ "ld.w $w20, (14*16)(%[accum_ptr])\n"
+ "ld.w $w13, (15*16)(%[accum_ptr])\n"
+ "ld.w $w17, (16*16)(%[accum_ptr])\n"
+ "ld.w $w21, (17*16)(%[accum_ptr])\n"
+ "ld.w $w14, (18*16)(%[accum_ptr])\n"
+ "ld.w $w18, (19*16)(%[accum_ptr])\n"
+ "ld.w $w22, (20*16)(%[accum_ptr])\n"
+ "ld.w $w15, (21*16)(%[accum_ptr])\n"
+ "ld.w $w19, (22*16)(%[accum_ptr])\n"
+ "ld.w $w23, (23*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w31, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A quarter of the 2 2x4 cells of Rhs is stored in 32bit in w30.
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w24-w29.
+ // A 12x8 block of accumulators is stored in 32bit in w0-w23.
+ //
+ // +------+------+------+------+
+ // Rhs |w30[0]|w30[1]|w30[2]|w30[3]|
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+---+ - - - - +------+------+------+------+
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // +---+---+ - - - - +------+------+------+------+
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23|
+ // +---+---+ - - - - +------+------+------+------+
+ //
+ // Accumulator
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w24, 0(%[lhs_ptr])\n"
+ "ld.b $w25, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for the first half of depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w24, $w31, $w24\n"
+ "ilvl.b $w26, $w31, $w25\n"
+ "ilvr.b $w25, $w31, $w25\n"
+ // Zero-extend 16-bit elements of lhs[] to 32 bits.
+ "ilvl.h $w27, $w31, $w24\n"
+ "ilvl.h $w28, $w31, $w25\n"
+ "ilvl.h $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w31, $w24\n"
+ "ilvr.h $w25, $w31, $w25\n"
+ "ilvr.h $w26, $w31, $w26\n"
+
+ // Depth 0.
+ "fill.w $w30, $a0\n"
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "maddv.w $w0, $w24, $w30\n"
+ "maddv.w $w4, $w25, $w30\n"
+ "maddv.w $w8, $w26, $w30\n"
+ "fill.w $w30, $a1\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "maddv.w $w1, $w24, $w30\n"
+ "maddv.w $w5, $w25, $w30\n"
+ "maddv.w $w9, $w26, $w30\n"
+ "fill.w $w30, $a2\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "maddv.w $w2, $w24, $w30\n"
+ "maddv.w $w6, $w25, $w30\n"
+ "maddv.w $w10, $w26, $w30\n"
+ "fill.w $w30, $a3\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ "maddv.w $w3, $w24, $w30\n"
+ "maddv.w $w7, $w25, $w30\n"
+ "maddv.w $w11, $w26, $w30\n"
+
+ "fill.w $w30, $a0\n"
+ "lbu $a0, 4(%[rhs_ptr])\n"
+ "maddv.w $w12, $w24, $w30\n"
+ "maddv.w $w16, $w25, $w30\n"
+ "maddv.w $w20, $w26, $w30\n"
+ "fill.w $w30, $a1\n"
+ "lbu $a1, 5(%[rhs_ptr])\n"
+ "maddv.w $w13, $w24, $w30\n"
+ "maddv.w $w17, $w25, $w30\n"
+ "maddv.w $w21, $w26, $w30\n"
+ "fill.w $w30, $a2\n"
+ "lbu $a2, 6(%[rhs_ptr])\n"
+ "maddv.w $w14, $w24, $w30\n"
+ "maddv.w $w18, $w25, $w30\n"
+ "maddv.w $w22, $w26, $w30\n"
+ "fill.w $w30, $a3\n"
+ "lbu $a3, 7(%[rhs_ptr])\n"
+ "maddv.w $w15, $w24, $w30\n"
+ "maddv.w $w19, $w25, $w30\n"
+ "maddv.w $w23, $w26, $w30\n"
+
+ // Depth 1.
+ "fill.w $w30, $a0\n"
+ "lbu $a0, 12(%[rhs_ptr])\n"
+ "maddv.w $w0, $w27, $w30\n"
+ "maddv.w $w4, $w28, $w30\n"
+ "maddv.w $w8, $w29, $w30\n"
+ "fill.w $w30, $a1\n"
+ "lbu $a1, 13(%[rhs_ptr])\n"
+ "maddv.w $w1, $w27, $w30\n"
+ "maddv.w $w5, $w28, $w30\n"
+ "maddv.w $w9, $w29, $w30\n"
+ "fill.w $w30, $a2\n"
+ "lbu $a2, 14(%[rhs_ptr])\n"
+ "maddv.w $w2, $w27, $w30\n"
+ "maddv.w $w6, $w28, $w30\n"
+ "maddv.w $w10, $w29, $w30\n"
+ "fill.w $w30, $a3\n"
+ "lbu $a3, 15(%[rhs_ptr])\n"
+ "maddv.w $w3, $w27, $w30\n"
+ "maddv.w $w7, $w28, $w30\n"
+ "maddv.w $w11, $w29, $w30\n"
+
+ "fill.w $w30, $a0\n"
+ "maddv.w $w12, $w27, $w30\n"
+ "maddv.w $w16, $w28, $w30\n"
+ "maddv.w $w20, $w29, $w30\n"
+ "fill.w $w30, $a1\n"
+ "maddv.w $w13, $w27, $w30\n"
+ "maddv.w $w17, $w28, $w30\n"
+ "maddv.w $w21, $w29, $w30\n"
+ "fill.w $w30, $a2\n"
+ "maddv.w $w14, $w27, $w30\n"
+ "maddv.w $w18, $w28, $w30\n"
+ "maddv.w $w22, $w29, $w30\n"
+ "fill.w $w30, $a3\n"
+ "maddv.w $w15, $w27, $w30\n"
+ "maddv.w $w19, $w28, $w30\n"
+ "maddv.w $w23, $w29, $w30\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ "st.w $w12, (12*16)(%[accum_ptr])\n"
+ "st.w $w16, (13*16)(%[accum_ptr])\n"
+ "st.w $w20, (14*16)(%[accum_ptr])\n"
+ "st.w $w13, (15*16)(%[accum_ptr])\n"
+ "st.w $w17, (16*16)(%[accum_ptr])\n"
+ "st.w $w21, (17*16)(%[accum_ptr])\n"
+ "st.w $w14, (18*16)(%[accum_ptr])\n"
+ "st.w $w18, (19*16)(%[accum_ptr])\n"
+ "st.w $w22, (20*16)(%[accum_ptr])\n"
+ "st.w $w15, (21*16)(%[accum_ptr])\n"
+ "st.w $w19, (22*16)(%[accum_ptr])\n"
+ "st.w $w23, (23*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "a0", "a1", "a2", "a3",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
+ "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
+ }
+};
+
+// Assembly implementation of the above
+// MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO).
+// Using 16x16=32 multiplications.
+// 32 MSA regs used:
+// - 24 accumulators
+// - 3 lhs
+// - 4 rhs
+// - 1 temps/zeroes
+// ~70 instructions in the loop.
+struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 {
+ typedef std::uint8_t OperandType;
+ typedef std::uint32_t AccumulatorType;
+ typedef KernelFormat<
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>,
+ KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> >
+ Format;
+ static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr,
+ AccumulatorType* accum_ptr, int depth) {
+ asm volatile(
+ // Load accumulators
+ "ld.w $w0, (0*16)(%[accum_ptr])\n"
+ "ld.w $w4, (1*16)(%[accum_ptr])\n"
+ "ld.w $w8, (2*16)(%[accum_ptr])\n"
+ "ld.w $w1, (3*16)(%[accum_ptr])\n"
+ "ld.w $w5, (4*16)(%[accum_ptr])\n"
+ "ld.w $w9, (5*16)(%[accum_ptr])\n"
+ "ld.w $w2, (6*16)(%[accum_ptr])\n"
+ "ld.w $w6, (7*16)(%[accum_ptr])\n"
+ "ld.w $w10, (8*16)(%[accum_ptr])\n"
+ "ld.w $w3, (9*16)(%[accum_ptr])\n"
+ "ld.w $w7, (10*16)(%[accum_ptr])\n"
+ "ld.w $w11, (11*16)(%[accum_ptr])\n"
+ "ld.w $w12, (12*16)(%[accum_ptr])\n"
+ "ld.w $w16, (13*16)(%[accum_ptr])\n"
+ "ld.w $w20, (14*16)(%[accum_ptr])\n"
+ "ld.w $w13, (15*16)(%[accum_ptr])\n"
+ "ld.w $w17, (16*16)(%[accum_ptr])\n"
+ "ld.w $w21, (17*16)(%[accum_ptr])\n"
+ "ld.w $w14, (18*16)(%[accum_ptr])\n"
+ "ld.w $w18, (19*16)(%[accum_ptr])\n"
+ "ld.w $w22, (20*16)(%[accum_ptr])\n"
+ "ld.w $w15, (21*16)(%[accum_ptr])\n"
+ "ld.w $w19, (22*16)(%[accum_ptr])\n"
+ "ld.w $w23, (23*16)(%[accum_ptr])\n"
+ // Set a temp to all zeroes.
+ "ldi.b $w31, 0\n"
+
+ GEMMLOWP_LABEL_LOOP ":\n"
+ // Overview of register layout:
+ //
+ // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30
+ // (each register contains 4 replicas of a pair of elements).
+ // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26.
+ // A 12x8 block of accumulators is stored in 32bit in w0-w23.
+ //
+ // +------+------+------+------+
+ // Rhs |w27 |w28 |w29 |w30 |
+ // +------+------+------+------+
+ //
+ // | | | | |
+ //
+ // Lhs | | | | |
+ //
+ // +---+ - - - - +------+------+------+------+
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // |w24| |w0/12 |w1/13 |w2/14 |w3/15 |
+ // +---+ - - - - +------+------+------+------+
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // |w25| |w4/16 |w5/17 |w6/18 |w7/19 |
+ // +---+ - - - - +------+------+------+------+
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // |w26| |w8/20 |w9/21 |w10/22|w11/23|
+ // +---+ - - - - +------+------+------+------+
+ //
+ // Accumulators
+
+ // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads.
+ "ld.b $w24, 0(%[lhs_ptr])\n"
+ "ld.b $w25, 8(%[lhs_ptr])\n"
+
+ // Load 4 bytes of rhs[] for the first half of depth 0.
+ "lbu $a0, 0(%[rhs_ptr])\n"
+ "lbu $a1, 1(%[rhs_ptr])\n"
+ "lbu $a2, 2(%[rhs_ptr])\n"
+ "lbu $a3, 3(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the first half of depth 1.
+ "lbu $v0, 4(%[rhs_ptr])\n"
+ "lbu $v1, 5(%[rhs_ptr])\n"
+ "lbu $t8, 6(%[rhs_ptr])\n"
+ "lbu $t9, 7(%[rhs_ptr])\n"
+
+ // Zero-extend 8-bit elements of lhs[] to 16 bits.
+ "ilvr.b $w24, $w31, $w24\n"
+ "ilvl.b $w26, $w31, $w25\n"
+ "ilvr.b $w25, $w31, $w25\n"
+ // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w.
+ "ilvl.d $w27, $w31, $w24\n"
+ "ilvl.d $w28, $w31, $w25\n"
+ "ilvl.d $w29, $w31, $w26\n"
+ "ilvr.h $w24, $w27, $w24\n"
+ "ilvr.h $w25, $w28, $w25\n"
+ "ilvr.h $w26, $w29, $w26\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
+ // (for the first half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Load 4 bytes of rhs[] for the second half of depth 0.
+ "lbu $a0, 8(%[rhs_ptr])\n"
+ "lbu $a1, 9(%[rhs_ptr])\n"
+ "lbu $a2, 10(%[rhs_ptr])\n"
+ "lbu $a3, 11(%[rhs_ptr])\n"
+ // Load 4 bytes of rhs[] for the second half of depth 1.
+ "lbu $v0, 12(%[rhs_ptr])\n"
+ "lbu $v1, 13(%[rhs_ptr])\n"
+ "lbu $t8, 14(%[rhs_ptr])\n"
+ "lbu $t9, 15(%[rhs_ptr])\n"
+
+ // First half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w0, $w24, $w27\n"
+ "dpadd_u.w $w4, $w25, $w27\n"
+ "dpadd_u.w $w8, $w26, $w27\n"
+ "dpadd_u.w $w1, $w24, $w28\n"
+ "dpadd_u.w $w5, $w25, $w28\n"
+ "dpadd_u.w $w9, $w26, $w28\n"
+ "dpadd_u.w $w2, $w24, $w29\n"
+ "dpadd_u.w $w6, $w25, $w29\n"
+ "dpadd_u.w $w10, $w26, $w29\n"
+ "dpadd_u.w $w3, $w24, $w30\n"
+ "dpadd_u.w $w7, $w25, $w30\n"
+ "dpadd_u.w $w11, $w26, $w30\n"
+
+ // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w
+ // (for the second half).
+ "ins $a0, $v0, 16, 8\n"
+ "ins $a1, $v1, 16, 8\n"
+ "ins $a2, $t8, 16, 8\n"
+ "ins $a3, $t9, 16, 8\n"
+ // Make 4 replicas of every pair of rhs[] elements.
+ "fill.w $w27, $a0\n"
+ "fill.w $w28, $a1\n"
+ "fill.w $w29, $a2\n"
+ "fill.w $w30, $a3\n"
+
+ // Second half of depths 0 and 1.
+ // Dot-product-(and)-add doubles multiplicand width.
+ "dpadd_u.w $w12, $w24, $w27\n"
+ "dpadd_u.w $w16, $w25, $w27\n"
+ "dpadd_u.w $w20, $w26, $w27\n"
+ "dpadd_u.w $w13, $w24, $w28\n"
+ "dpadd_u.w $w17, $w25, $w28\n"
+ "dpadd_u.w $w21, $w26, $w28\n"
+ "dpadd_u.w $w14, $w24, $w29\n"
+ "dpadd_u.w $w18, $w25, $w29\n"
+ "dpadd_u.w $w22, $w26, $w29\n"
+ "dpadd_u.w $w15, $w24, $w30\n"
+ "dpadd_u.w $w19, $w25, $w30\n"
+ "dpadd_u.w $w23, $w26, $w30\n"
+
+ "addiu %[depth], -2\n"
+ GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n"
+ GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n"
+ "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n"
+
+ // Store accumulators.
+ "st.w $w0, (0*16)(%[accum_ptr])\n"
+ "st.w $w4, (1*16)(%[accum_ptr])\n"
+ "st.w $w8, (2*16)(%[accum_ptr])\n"
+ "st.w $w1, (3*16)(%[accum_ptr])\n"
+ "st.w $w5, (4*16)(%[accum_ptr])\n"
+ "st.w $w9, (5*16)(%[accum_ptr])\n"
+ "st.w $w2, (6*16)(%[accum_ptr])\n"
+ "st.w $w6, (7*16)(%[accum_ptr])\n"
+ "st.w $w10, (8*16)(%[accum_ptr])\n"
+ "st.w $w3, (9*16)(%[accum_ptr])\n"
+ "st.w $w7, (10*16)(%[accum_ptr])\n"
+ "st.w $w11, (11*16)(%[accum_ptr])\n"
+ "st.w $w12, (12*16)(%[accum_ptr])\n"
+ "st.w $w16, (13*16)(%[accum_ptr])\n"
+ "st.w $w20, (14*16)(%[accum_ptr])\n"
+ "st.w $w13, (15*16)(%[accum_ptr])\n"
+ "st.w $w17, (16*16)(%[accum_ptr])\n"
+ "st.w $w21, (17*16)(%[accum_ptr])\n"
+ "st.w $w14, (18*16)(%[accum_ptr])\n"
+ "st.w $w18, (19*16)(%[accum_ptr])\n"
+ "st.w $w22, (20*16)(%[accum_ptr])\n"
+ "st.w $w15, (21*16)(%[accum_ptr])\n"
+ "st.w $w19, (22*16)(%[accum_ptr])\n"
+ "st.w $w23, (23*16)(%[accum_ptr])\n"
+ : // outputs
+ [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+ [depth] "+r"(depth)
+ : // inputs
+ [accum_ptr] "r"(accum_ptr)
+ : // clobbers
+ "memory",
+ "v0", "v1",
+ "a0", "a1", "a2", "a3",
+ "t8", "t9",
+ "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7",
+ "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15",
+ "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23",
+ "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31");
+ }
+};
+#endif // __mips
// BEGIN code copied from gemmlowp/internal/kernel_reference.h
@@ -3451,8 +4883,9 @@ class CacheLineAlignedBuffer {
data_ = nullptr;
// Adds a few bytes of padding here, because the 64-bit 'A57' kernel
// reads one iteration past the end the buffer, causing a crash on iOS.
- posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize,
- size_ * sizeof(DataType) + 16);
+ int res = posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize,
+ size_ * sizeof(DataType) + 16);
+ (void)res;
}
~CacheLineAlignedBuffer() { free(data_); }
@@ -3460,7 +4893,7 @@ class CacheLineAlignedBuffer {
const DataType* data() const { return data_; }
DataType* data() { return data_; }
- const std::size_t size() const { return size_; }
+ std::size_t size() const { return size_; }
private:
const std::size_t size_;
@@ -3726,12 +5159,15 @@ int main() {
#endif
#ifdef __aarch64__
-
BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits);
BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics);
BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57);
+#ifdef __ARM_FEATURE_DOTPROD
+ BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct);
+ BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1);
+#endif
BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar);
BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar);
BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar);
@@ -3740,6 +5176,16 @@ int main() {
#ifndef __APPLE__
BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A53);
#endif
+ BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A55r1);
+#endif
+
+#ifdef __mips
+ BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics);
+ BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly);
+ BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly);
+ BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2);
#endif
return 0;