author     Miao Wang <miaowang@google.com>                          2017-07-26 08:35:51 +0000
committer  android-build-merger <android-build-merger@google.com>  2017-07-26 08:35:51 +0000
commit     df42d0b1455deec446c3d840647d283da03b6177 (patch)
tree       08f659fe1f92d26d276ef886f9a16a92ac51d344
parent     9f9f0a7fb40cc12913c5da5015367e772401b3b6 (diff)
parent     7ca694a42d6b0794c9ec4b513fe1abe20eb09e2c (diff)
download   gemmlowp-df42d0b1455deec446c3d840647d283da03b6177.tar.gz
Make gemmlowp properly support host modules.
am: 7ca694a42d
Change-Id: Ife77d2f581adf742c7de1591aeda093f00543702
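[Reviewer note, not part of the change: with the two cc_library_headers modules below marked host_supported, a host-side module that lists "libgemmlowp" (and, if needed, "libfixedpoint") in its header_libs resolves the exported include dirs directly. A minimal hypothetical C++ consumer, assuming "gemmlowp.h" is the header exported from public/:

    // Hypothetical host-side user of the new header libraries; MatrixMap
    // wraps existing storage and does not own memory.
    #include "gemmlowp.h"

    #include <cstdint>

    int main() {
      const std::uint8_t data[4] = {1, 2, 3, 4};
      gemmlowp::MatrixMap<const std::uint8_t, gemmlowp::MapOrder::RowMajor> m(
          data, 2, 2);
      return m(1, 1) == 4 ? 0 : 1;  // element at row 1, col 1 is 4
    }
]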
-rw-r--r--  Android.bp                 |  12
-rw-r--r--  internal/allocator.h       |   2
-rw-r--r--  internal/common.h          |   2
-rw-r--r--  internal/fixedpoint.h      | 552
-rw-r--r--  internal/fixedpoint_neon.h | 165
-rw-r--r--  internal/iterator.h        |  78
-rw-r--r--  internal/kernel_SSE.h      | 512
-rw-r--r--  internal/pack_SSE.h        | 169
-rw-r--r--  internal/unpack_neon.h     | 212
9 files changed, 14 insertions, 1690 deletions
diff --git a/Android.bp b/Android.bp
@@ -12,4 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+cc_library_headers {
+    name: "libfixedpoint",
+    host_supported: true,
+    export_include_dirs: ["fixedpoint"],
+}
+
+cc_library_headers {
+    name: "libgemmlowp",
+    host_supported: true,
+    export_include_dirs: ["public"],
+}
+
 subdirs = ["eight_bit_int_gemm"]
diff --git a/internal/allocator.h b/internal/allocator.h
index 0fe4a01..da325a4 100644
--- a/internal/allocator.h
+++ b/internal/allocator.h
@@ -41,7 +41,7 @@
 
 #include "common.h"
 
-#if defined ANDROID || defined __ANDROID__
+#if defined(__ANDROID__)
 #include <android/api-level.h>
 // The 18 here should be 16, but has to be 18 for now due
 // to a Google-internal issue.
diff --git a/internal/common.h b/internal/common.h
index 1d89b26..511809d 100644
--- a/internal/common.h
+++ b/internal/common.h
@@ -117,7 +117,7 @@
 // Detect Android. Don't conflate with ARM - we care about tuning
 // for non-ARM Android devices too. This can be used in conjunction
 // with x86 to tune differently for mobile x86 CPUs (Atom) vs. desktop x86 CPUs.
-#if defined(__ANDROID__) || defined(ANDROID)
+#if defined(__ANDROID__)
 #define GEMMLOWP_ANDROID
 #endif
diff --git a/internal/fixedpoint.h b/internal/fixedpoint.h
deleted file mode 100644
index 340331b..0000000
--- a/internal/fixedpoint.h
+++ /dev/null
@@ -1,552 +0,0 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// fixedpoint.h: fixed-point arithmetic, with basic operations and
-// a few math functions such as tanh.
-
-// This is only used in output.h
-// for some specific output pipeline stages (tanh); most of gemmlowp
-// uses only plain integer arithmetic, not fixed-point arithmetic.
-// At the most basic level, we distinguish between plain integer
-// arithmetic and fixed-point arithmetic by the type of multiplication
-// that is used: plain integer arithmetic uses plain (overflowing)
-// integer multiplication, whereas fixed-point arithmetic uses
-// "multiply-high" instructions, which means using only the most
-// significant bits of the product, or equivalently, multiplying
-// fixed-point numbers in the [-1 .. +1] interval.
-
-#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
-#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
-
-#include "common.h"
-
-#include <limits>
-#include <cassert>
-
-namespace gemmlowp {
-
-template <typename tIntegerType>
-tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
-  return a & b;
-}
-
-template <typename tIntegerType>
-tIntegerType BitOr(tIntegerType a, tIntegerType b) {
-  return a | b;
-}
-
-template <typename tIntegerType>
-tIntegerType BitXor(tIntegerType a, tIntegerType b) {
-  return a ^ b;
-}
-
-template <typename tIntegerType>
-tIntegerType BitNot(tIntegerType a) {
-  return ~a;
-}
-
-template <typename tIntegerType>
-tIntegerType Add(tIntegerType a, tIntegerType b) {
-  return a + b;
-}
-
-template <typename tIntegerType>
-tIntegerType Sub(tIntegerType a, tIntegerType b) {
-  return a - b;
-}
-
-template <typename tIntegerType>
-tIntegerType Neg(tIntegerType a) {
-  return -a;
-}
-
-template <typename tIntegerType>
-tIntegerType ShiftLeft(tIntegerType a, int offset) {
-  return a * (1 << offset);
-}
-
-template <typename tIntegerType>
-tIntegerType ShiftRight(tIntegerType a, int offset) {
-  return a / (1 << offset);
-}
-
-template <typename tIntegerType>
-tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
-                             tIntegerType else_val) {
-  return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
-}
-
-template <typename tIntegerType>
-tIntegerType MaskIfNonZero(tIntegerType a) {
-  static const tIntegerType zero = 0;
-  return a ? BitNot(zero) : zero;
-}
-
-template <typename tIntegerType>
-tIntegerType MaskIfZero(tIntegerType a) {
-  return MaskIfNonZero<tIntegerType>(!a);
-}
-
-template <typename tIntegerType>
-tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a == b);
-}
-
-template <typename tIntegerType>
-tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a != b);
-}
-
-template <typename tIntegerType>
-tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a > b);
-}
-
-template <typename tIntegerType>
-tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a >= b);
-}
-
-template <typename tIntegerType>
-tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a < b);
-}
-
-template <typename tIntegerType>
-tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
-  return MaskIfNonZero<tIntegerType>(a <= b);
-}
-
-template <typename tIntegerType>
-bool All(tIntegerType a) {
-  return a;
-}
-
-template <typename tIntegerType>
-bool Any(tIntegerType a) {
-  return a;
-}
-
-template <typename IntegerType>
-IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
-  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
-  return a;
-}
-
-template <>
-inline int32_t RoundingHalfSum(int32_t a, int32_t b) {
-  int64_t a64 = a;
-  int64_t b64 = b;
-  int64_t sum = a64 + b64;
-  int64_t sign = sum >= 0 ? 1 : -1;
-  return static_cast<int32_t>((sum + sign) / 2);
-}
-
-template <typename IntegerType>
-IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
-  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
-  return a;
-}
-
-// This function implements the same computation as the ARMv7 NEON VQRDMULH
-// instruction.
-template <>
-inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
-  bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
-  int64_t a_64(a);
-  int64_t b_64(b);
-  int64_t ab_64 = a_64 * b_64;
-  int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
-  int32_t ab_x2_high32 = static_cast<int32_t>((ab_64 + nudge) / (1ll << 31));
-  return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32;
-}
-
-template <int Exponent, typename IntegerType,
-          int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
-struct ImplSaturatingRoundingMultiplyByPOT {};
-
-template <int Exponent, typename IntegerType>
-struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
-  static IntegerType eval(IntegerType x) { return x; }
-};
-
-template <int Exponent>
-struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, 1> {
-  static int32_t eval(int32_t x) {
-    const int64_t min = std::numeric_limits<int32_t>::min();
-    const int64_t max = std::numeric_limits<int32_t>::max();
-    return x >= (1 << (31 - Exponent)) ? max : x <= -(1 << (31 - Exponent))
-                                                   ? min
-                                                   : x * (1 << Exponent);
-  }
-};
-
-template <int Exponent>
-struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, -1> {
-  static int32_t eval(int32_t x) {
-    int32_t b = (std::abs(x) & (1 << (-Exponent - 1))) >> (-Exponent - 1);
-    int32_t nudge = x >= 0 ? b : -b;
-    return x / (1 << -Exponent) + nudge;
-  }
-};
-
-template <int Exponent, typename IntegerType>
-IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
-  return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
-}
-
-template <typename tIntegerType>
-struct FixedPointRawTypeTraits {};
-
-template <>
-struct FixedPointRawTypeTraits<int32_t> {
-  typedef int32_t ScalarRawType;
-  static const int kLanes = 1;
-};
-
-template <typename tRawType>
-tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
-  return x;
-}
-
-template <typename tRawType, int tIntegerBits>
-class FixedPoint {
- public:
-  typedef tRawType RawType;
-
-  typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
-  typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
-
-  static const int kTotalBits = 8 * sizeof(ScalarRawType);
-  static const int kIntegerBits = tIntegerBits;
-  static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
-  static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
-                "bad IntegerBits");
-
-  typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
-
-  static const ScalarRawType ScalarRawMin() {
-    return std::numeric_limits<ScalarRawType>::min();
-  }
-
-  static const ScalarRawType ScalarRawMax() {
-    return std::numeric_limits<ScalarRawType>::max();
-  }
-
-  static const ScalarRawType RawMin() {
-    return VectorFromScalar(ScalarRawMin());
-  }
-
-  static const ScalarRawType RawMax() {
-    return VectorFromScalar(ScalarRawMax());
-  }
-
-  static FixedPoint FromRaw(RawType x) {
-    FixedPoint retval;
-    retval.raw() = x;
-    return retval;
-  }
-
-  static FixedPoint FromScalarRaw(ScalarRawType x) {
-    FixedPoint retval;
-    retval.raw() = Dup<RawType>(x);
-    return retval;
-  }
-
-  static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
-    return FromScalarRaw(x.raw());
-  }
-
-  template <int Exponent>
-  static FixedPoint ConstantPOT() {
-    static const int kOffset = kFractionalBits + Exponent;
-    static_assert(
-        kOffset < 31,
-        "Constant not exactly representable in this fixed-point format");
-    return FromScalarRaw(ScalarRawType(1) << kOffset);
-  }
-
-  static FixedPoint Zero() { return FromScalarRaw(0); }
-
-  static FixedPoint One() {
-    return FromScalarRaw(kIntegerBits == 0
-                             ? ScalarRawMax()
-                             : (ScalarRawType(1) << kFractionalBits));
-  }
-
-  RawType raw() const { return i_; }
-  RawType& raw() { return i_; }
-
- private:
-  RawType i_;
-};
-
-template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
-FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
-    FixedPoint<tRawType, tIntegerBits_a> a,
-    FixedPoint<tRawType, tIntegerBits_b> b) {
-  FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
-  c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
-  return c;
-}
-
-template <int tExponent, typename tRawType, int tIntegerBits>
-FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
-    FixedPoint<tRawType, tIntegerBits> a) {
-  FixedPoint<tRawType, tExponent + tIntegerBits> c;
-  c.raw() = a.raw();
-  return c;
-}
-
-template <int tExponent, typename tRawType, int tIntegerBits>
-FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
-    FixedPoint<tRawType, tIntegerBits> a) {
-  return FixedPoint<tRawType, tIntegerBits>::FromRaw(
-      SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
-}
-
-#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
-  template <typename tRawType, int tIntegerBits>                               \
-  FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
-      FixedPoint<tRawType, tIntegerBits> a) {                                  \
-    return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
-  }
-
-#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
-  template <typename tRawType, int tIntegerBits>            \
-  FixedPoint<tRawType, tIntegerBits> FuncName(              \
-      FixedPoint<tRawType, tIntegerBits> a,                 \
-      FixedPoint<tRawType, tIntegerBits> b) {               \
-    return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
-        ImplFuncName(a.raw(), b.raw()));                    \
-  }
-
-MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
-MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
-MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
-MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
-MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
-MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
-MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
-MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
-
-#undef MAKE_FIXEDPOINT_UNARY_FUNC
-#undef MAKE_FIXEDPOINT_BINARY_FUNC
-
-#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
-  template <typename tRawType, int tIntegerBits>            \
-  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
-    return FuncName(a.raw());                               \
-  }
-
-#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
-  template <typename tRawType, int tIntegerBits>            \
-  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
-                    FixedPoint<tRawType, tIntegerBits> b) { \
-    return FuncName(a.raw(), b.raw());                      \
-  }
-
-MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
-MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
-MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
-MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
-MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
-MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
-MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
-MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
-
-#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
-#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
-
-template <typename tRawType, int tIntegerBits>
-FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
-    tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
-    FixedPoint<tRawType, tIntegerBits> else_val) {
-  return FixedPoint<tRawType, tIntegerBits>::FromRaw(
-      SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
-}
-
-template <typename tRawType, int tIntegerBits>
-bool operator==(FixedPoint<tRawType, tIntegerBits> a,
-                FixedPoint<tRawType, tIntegerBits> b) {
-  return All(MaskIfEqual(a.raw(), b.raw()));
-}
-
-template <typename tRawType, int tIntegerBits>
-bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
-                FixedPoint<tRawType, tIntegerBits> b) {
-  return !(a == b);
-}
-
-template <typename tRawType, int tIntegerBits>
-double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
-  static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
-                "not applicable to SIMD types");
-  typedef FixedPoint<tRawType, tIntegerBits> F;
-  return x.raw() / double(1ll << F::kFractionalBits);
-}
-
-template <typename tRawType, int tIntegerBits>
-FixedPoint<tRawType, tIntegerBits> ToFixedPoint(double x) {
-  typedef FixedPoint<tRawType, tIntegerBits> F;
-  return F::FromScalarRaw(static_cast<int32_t>(
-      std::min(std::max(round(x * double(1ll << F::kFractionalBits)),
-                        double(F::ScalarRawMin())),
-               double(F::ScalarRawMax()))));
-}
-
-template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
-FixedPoint<tRawType, tIntegerBitsDst> Rescale(
-    FixedPoint<tRawType, tIntegerBitsSrc> x) {
-  static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
-  FixedPoint<tRawType, tIntegerBitsDst> result;
-  result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
-  return result;
-}
-
-#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
-template <typename FixedPointType>
-FixedPointType CheckedFixedPointConstant(
-    typename FixedPointType::ScalarRawType raw_value, double double_value) {
-  typedef typename FixedPointType::RawType RawType;
-  static const int kIntegerBits = FixedPointType::kIntegerBits;
-  FixedPointType ref = FixedPointType::FromScalarRaw(raw_value);
-  FixedPointType check = ToFixedPoint<RawType, kIntegerBits>(double_value);
-  assert(ref == check);
-  return ref;
-}
-#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
-                                             DoubleValue)                    \
-  (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue))
-
-#else
-#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
-                                             DoubleValue)                    \
-  (FixedPointType::FromScalarRaw(ScalarRawValue))
-#endif
-
-template <typename tRawType>
-FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
-    FixedPoint<tRawType, 0> a) {
-  typedef FixedPoint<tRawType, 0> F;
-  const F constant_term =
-      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
-  const F constant_1_over_3 =
-      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
-  // We're evaluating a Taylor expansion around -1/8, so we do the change of
-  // variable: x = a + 1/8.
-  // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
-  F x = a + F::template ConstantPOT<-3>();
-  F x2 = x * x;
-  F x3 = x2 * x;
-  F x4 = x2 * x2;
-  F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
-  F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
-      SaturatingRoundingMultiplyByPOT<-1>(
-          ((x4_over_4 + x3) * constant_1_over_3) + x2);
-  return constant_term +
-         constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2);
-}
-
-template <typename tRawType, int tIntegerBits>
-FixedPoint<tRawType, 0> exp_on_negative_values(
-    FixedPoint<tRawType, tIntegerBits> a) {
-  typedef FixedPoint<tRawType, tIntegerBits> InputF;
-  typedef FixedPoint<tRawType, 0> ResultF;
-  static const int kFractionalBits = InputF::kFractionalBits;
-  static const int kIntegerBits = InputF::kIntegerBits;
-  static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
-  InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
-  InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
-  ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
-      Rescale<0>(a_mod_quarter_minus_one_quarter));
-  tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
-
-#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
-  if (kIntegerBits > Exponent) {                                            \
-    const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
-        ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
-    result = SelectUsingMask(                                               \
-        MaskIfNonZero(BitAnd(                                               \
-            remainder, Dup<tRawType>(1 << (kFractionalBits + Exponent)))),  \
-        result * kMultiplier, result);                                      \
-  }
-
-  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
-  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
-  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
-
-#undef GEMMLOWP_EXP_BARREL_SHIFTER
-
-  if (kIntegerBits > 5) {
-    static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0;
-    const InputF clamp =
-        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
-    result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
-  }
-
-  result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
-  return result;
-}
-
-template <typename tRawType>
-FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
-    FixedPoint<tRawType, 0> a) {
-  typedef FixedPoint<tRawType, 0> F0;
-  typedef FixedPoint<tRawType, 2> F2;
-  F0 half_denominator = RoundingHalfSum(a, F0::One());
-  const F2 constant_48_over_17 =
-      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
-  const F2 constant_neg_32_over_17 =
-      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
-  F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
-  for (int i = 0; i < 3; i++) {
-    F2 half_denominator_times_x = half_denominator * x;
-    F2 one_minus_half_denominator_times_x =
-        F2::One() - half_denominator_times_x;
-    x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
-  }
-  return Rescale<0>(x - F2::One());
-}
-
-template <typename tRawType, int tIntegerBits>
-FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
-    FixedPoint<tRawType, tIntegerBits> a) {
-  return one_minus_x_over_one_plus_x_for_x_in_0_1(
-      exp_on_negative_values(ExactMulByPot<1>(a)));
-}
-
-template <typename tRawType, int tIntegerBits>
-FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
-  typedef FixedPoint<tRawType, tIntegerBits> InputF;
-  typedef FixedPoint<tRawType, 0> ResultF;
-  tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
-  tRawType mask_if_zero = MaskIfZero(a);
-  InputF n = SelectUsingMask(mask_if_negative, a, -a);
-  ResultF t = neg_tanh_on_negative_values(n);
-  return SelectUsingMask(mask_if_zero, ResultF::Zero(),
-                         SelectUsingMask(mask_if_negative, -t, t));
-}
-
-}  // end namespace gemmlowp
-
-#ifdef GEMMLOWP_NEON
-#include "fixedpoint_neon.h"
-#endif
-
-#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
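[Reviewer note, not part of the change: the deleted header's intro and the VQRDMULH comment above hinge on "multiply-high" arithmetic. A minimal, self-contained sketch of the scalar SaturatingRoundingDoublingHighMul shown above, treating int32 values as Q0.31 fractions in [-1, 1); QRDMulH is a hypothetical name for this illustration:

    #include <cstdint>
    #include <cstdio>
    #include <limits>

    // Doubling multiply-high with rounding; saturates only on min*min,
    // the one input pair whose true product (+1.0) is not representable.
    int32_t QRDMulH(int32_t a, int32_t b) {
      if (a == b && a == std::numeric_limits<int32_t>::min()) {
        return std::numeric_limits<int32_t>::max();
      }
      int64_t ab = static_cast<int64_t>(a) * b;
      int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
      return static_cast<int32_t>((ab + nudge) / (1ll << 31));
    }

    int main() {
      int32_t half = 1 << 30;  // 0.5 in Q0.31
      // 0.5 * 0.5 = 0.25, i.e. 1 << 29 in Q0.31.
      printf("%d (expect %d)\n", QRDMulH(half, half), 1 << 29);
    }
]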
diff --git a/internal/fixedpoint_neon.h b/internal/fixedpoint_neon.h
deleted file mode 100644
index f5688ba..0000000
--- a/internal/fixedpoint_neon.h
+++ /dev/null
@@ -1,165 +0,0 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// fixedpoint_neon.h: optimized NEON specializations of the templates
-// in fixedpoint.h.
-
-#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
-#define GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
-
-#include "fixedpoint.h"
-
-#include <arm_neon.h>
-
-namespace gemmlowp {
-
-template <>
-inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
-  return vandq_s32(a, b);
-}
-
-template <>
-inline int32x4_t BitOr(int32x4_t a, int32x4_t b) {
-  return vorrq_s32(a, b);
-}
-
-template <>
-inline int32x4_t BitXor(int32x4_t a, int32x4_t b) {
-  return veorq_s32(a, b);
-}
-
-template <>
-inline int32x4_t BitNot(int32x4_t a) {
-  return veorq_s32(a, vdupq_n_s32(-1));
-}
-
-template <>
-inline int32x4_t Add(int32x4_t a, int32x4_t b) {
-  return vaddq_s32(a, b);
-}
-
-template <>
-inline int32x4_t Sub(int32x4_t a, int32x4_t b) {
-  return vsubq_s32(a, b);
-}
-
-template <>
-inline int32x4_t Neg(int32x4_t a) {
-  return vnegq_s32(a);
-}
-
-template <>
-inline int32x4_t ShiftLeft(int32x4_t a, int offset) {
-  return vshlq_s32(a, vdupq_n_s32(offset));
-}
-
-template <>
-inline int32x4_t ShiftRight(int32x4_t a, int offset) {
-  return vshlq_s32(a, vdupq_n_s32(-offset));
-}
-
-template <>
-inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val,
-                                 int32x4_t else_val) {
-  return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val);
-}
-
-template <>
-inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) {
-  return vreinterpretq_s32_u32(vceqq_s32(a, b));
-}
-
-template <>
-inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) {
-  return BitNot(MaskIfEqual(a, b));
-}
-
-template <>
-inline int32x4_t MaskIfZero(int32x4_t a) {
-  return MaskIfEqual(a, vdupq_n_s32(0));
-}
-
-template <>
-inline int32x4_t MaskIfNonZero(int32x4_t a) {
-  return vreinterpretq_s32_u32(vtstq_s32(a, a));
-}
-
-template <>
-inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
-  return vreinterpretq_s32_u32(vcgtq_s32(a, b));
-}
-
-template <>
-inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) {
-  return vreinterpretq_s32_u32(vcgeq_s32(a, b));
-}
-
-template <>
-inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
-  return vreinterpretq_s32_u32(vcltq_s32(a, b));
-}
-
-template <>
-inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) {
-  return vreinterpretq_s32_u32(vcleq_s32(a, b));
-}
-
-template <>
-inline bool All(int32x4_t a) {
-  a = vandq_s32(a, vextq_s32(a, a, 1));
-  a = vandq_s32(a, vextq_s32(a, a, 2));
-  return vgetq_lane_s32(a, 0);
-}
-
-template <>
-inline bool Any(int32x4_t a) {
-  a = vorrq_s32(a, vextq_s32(a, a, 1));
-  a = vorrq_s32(a, vextq_s32(a, a, 2));
-  return vgetq_lane_s32(a, 0);
-}
-
-template <>
-inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) {
-  return vrhaddq_s32(a, b);
-}
-
-template <>
-inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) {
-  return vqrdmulhq_s32(a, b);
-}
-
-template <int Exponent>
-struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
-  static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
-};
-
-template <int Exponent>
-struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, -1> {
-  static int32x4_t eval(int32x4_t x) { return vrshrq_n_s32(x, -Exponent); }
-};
-
-template <>
-struct FixedPointRawTypeTraits<int32x4_t> {
-  typedef int32_t ScalarRawType;
-  static const int kLanes = 4;
-};
-
-template <>
-inline int32x4_t Dup<int32x4_t>(int32_t x) {
-  return vdupq_n_s32(x);
-}
-
-}  // end namespace gemmlowp
-
-#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
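[Reviewer note, not part of the change: the NEON file maps SelectUsingMask onto vbslq_s32, relying on masks that are all-ones or all-zeros per lane. A portable scalar sketch of that same mask-select contract (illustration only):

    #include <cassert>
    #include <cstdint>

    // Branchless select: if_mask must be all-ones (-1) or all-zeros (0),
    // matching what the MaskIf* helpers produce.
    int32_t SelectUsingMask(int32_t if_mask, int32_t then_val,
                            int32_t else_val) {
      return (if_mask & then_val) ^ (~if_mask & else_val);
    }

    int main() {
      int32_t mask_true = -1;  // what MaskIfNonZero(x) yields for x != 0
      int32_t mask_false = 0;
      assert(SelectUsingMask(mask_true, 7, 9) == 7);
      assert(SelectUsingMask(mask_false, 7, 9) == 9);
    }
]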
diff --git a/internal/iterator.h b/internal/iterator.h
deleted file mode 100644
index 524cb80..0000000
--- a/internal/iterator.h
+++ /dev/null
@@ -1,78 +0,0 @@
-// Copyright 2016 Google Inc. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// iterator.h: Const forward iterators for VectorMap and VectorDup that help
-// access data in architecture specific way, e.g. 4 elements at a time for
-// NEON.
-
-#ifndef GEMMLOWP_INTERNAL_ITERATOR_H_
-#define GEMMLOWP_INTERNAL_ITERATOR_H_
-
-namespace gemmlowp {
-
-enum class VectorShape;
-
-// ConstIterator is a forward only constant iterator that can be made
-// architecture specific e.g. to return 4 values at once for NEON.
-template <typename VectorType>
-class ConstIterator {
-  // Unused default case.
-};
-
-template <typename tScalar, VectorShape tShape>
-class VectorMap;
-
-template <typename tScalar, VectorShape tShape>
-class ConstIterator<VectorMap<tScalar, tShape>> {
- public:
-  typedef tScalar Scalar;
-  ConstIterator(const VectorMap<tScalar, tShape>& vector_map,
-                const int start_offset)
-      : pointer_(vector_map.data() + start_offset) {}
-  const Scalar operator*() const { return *pointer_; }
-  const Scalar* get() const { return pointer_; }
-  ConstIterator& operator+=(int inc) {
-    pointer_ += inc;
-    return *this;
-  }
-
- private:
-  const Scalar* pointer_;
-};
-
-template <typename tScalar, VectorShape tShape>
-ConstIterator<VectorMap<tScalar, tShape>> const_iterator(
-    const VectorMap<tScalar, tShape>& vector_map, const int start_offset) {
-  return ConstIterator<VectorMap<tScalar, tShape>>(vector_map, start_offset);
-}
-
-template <typename tScalar, VectorShape tShape>
-class VectorDup;
-
-template <typename tScalar, VectorShape tShape>
-class ConstIterator<VectorDup<tScalar, tShape>> {
- public:
-  typedef tScalar Scalar;
-  ConstIterator(const VectorDup<tScalar, tShape>& vector_dup)
-      : data_(vector_dup(0)) {}
-  const Scalar operator*() const { return data_; }
-  const Scalar* get() const { return &data_; }
-  ConstIterator& operator+=(int inc) { return *this; }
-
- private:
-  Scalar data_;
-};
-
-template <typename tScalar, VectorShape tShape>
-ConstIterator<VectorDup<tScalar, tShape>> const_iterator(
-    const VectorDup<tScalar, tShape>& vector_map, const int start_offset) {
-  return ConstIterator<VectorDup<tScalar, tShape>>(vector_map);
-}
-
-}  // namespace gemmlowp
-
-#endif  // GEMMLOWP_INTERNAL_ITERATOR_H_
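[Reviewer note, not part of the change: a tiny self-contained sketch of the forward-only const-iterator pattern above, using a hypothetical FakeVectorMap stand-in since the real VectorMap lives elsewhere in gemmlowp:

    #include <cassert>

    // Hypothetical stand-in for gemmlowp::VectorMap, just enough to drive
    // the ConstIterator pattern from iterator.h.
    struct FakeVectorMap {
      const int* data_;
      const int* data() const { return data_; }
    };

    template <typename VectorType>
    class ConstIterator;  // forward-only const iterator, as above

    template <>
    class ConstIterator<FakeVectorMap> {
     public:
      ConstIterator(const FakeVectorMap& map, int start_offset)
          : pointer_(map.data() + start_offset) {}
      int operator*() const { return *pointer_; }
      ConstIterator& operator+=(int inc) {
        pointer_ += inc;
        return *this;
      }

     private:
      const int* pointer_;
    };

    int main() {
      const int values[] = {10, 20, 30, 40};
      FakeVectorMap map{values};
      ConstIterator<FakeVectorMap> it(map, 0);
      assert(*it == 10);
      it += 2;  // a NEON specialization would advance 4 lanes at once
      assert(*it == 30);
    }
]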
diff --git a/internal/kernel_SSE.h b/internal/kernel_SSE.h
deleted file mode 100644
index dba44ea..0000000
--- a/internal/kernel_SSE.h
+++ /dev/null
@@ -1,512 +0,0 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// kernel_SSE.h: a collection of Intel SSE optimized kernels.
-// Check in kernel_default.h which one(s) are actually used by default.
-// Others are mere experiments; they are still covered by tests
-// in case they might be useful some day.
-//
-
-#ifndef GEMMLOWP_INTERNAL_KERNEL_SSE_H_
-#define GEMMLOWP_INTERNAL_KERNEL_SSE_H_
-
-#include "kernel.h"
-
-#include <string.h>
-#include <cassert>
-
-namespace gemmlowp {
-
-#ifdef GEMMLOWP_SSE4_32
-struct SSE4_32_Kernel4x4Depth2 : KernelBase {
-  typedef KernelFormat<
-      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1>,
-      KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1> >
-      Format;
-
-  const char* Name() const override { return "SSE, 4x4, depth 2"; }
-
-  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
-           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
-           const std::uint8_t* rhs_ptr, std::size_t start_depth,
-           std::size_t run_depth) const override {
-    ScopedProfilingLabel label("optimized kernel");
-    assert(dst_row_stride == 1);
-    std::int32_t run_depth_cells = run_depth / Format::kDepth;
-    /* Main loop */
-
-    // A 2x4 cell of Rhs is stored in 16bit in xmm1 .
-    // A 4x2 block Lhs is stored in 16bit in xmm0.
-    // A 4x4 block of accumulators is stored in 32bit in xmm4--xmm7.
-    //
-    //                   +-------+-------+-------+-------+
-    //                   |xmm1[0]|xmm1[2]|xmm1[4]|xmm1[6]|
-    //              Rhs  +-------+---------------+-------+
-    //                   |xmm1[1]|xmm1[3]|xmm1[5]|xmm1[7]|
-    //                   +-------+-------+-------+-------+
-    //
-    //                   |       |       |       |       |
-    //
-    //    Lhs            |       |       |       |       |
-    //
-    //  +--+--+ - - - -  +-------+-------+-------+-------+
-    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
-    //  |xmm0 | (Iter1)  | xmm4  | xmm5  | xmm6  | xmm7  |
-    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
-    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
-    //  +--+--+ - - - -  +-------+-------+-------+-------+
-    //
-    //                              Accumulator
-
-    asm volatile(
-
-        // set accumulators to zero.
- "pxor %%xmm4 , %%xmm4 \n\t" - "pxor %%xmm5 , %%xmm5 \n\t" - "pxor %%xmm6 , %%xmm6 \n\t" - "pxor %%xmm7 , %%xmm7 \n\t" - - "movl %[run_depth_cells], %%eax\n\t" - "subl $2, %%eax\n\t" - "js outerLoop1%=\n\t" - - // Loop for K unrolled by 4 - "outerLoop2%=:\n\t" - - // K = 1,2 - // RHS cell to xmm1 - "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t" - - // LHS cell - "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm4 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm5 \n\t" - - "prefetcht0 0x80(%[lhs_ptr]) \n\t" - - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm6 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm7 \n\t" - - "prefetcht0 0x80(%[rhs_ptr]) \n\t" - - // K = 3,4 - // RHS cell to xmm1 - "pmovzxbw 0x08(%[rhs_ptr]), %%xmm1\n\t" - - // LHS cell - "pmovzxbw 0x08(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm4 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm5 \n\t" - - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm6 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm7 \n\t" - - "addl $0x10, %[lhs_ptr]\n\t" - "addl $0x10, %[rhs_ptr]\n\t" - - "subl $2, %[run_depth_cells]\n\t" - "jnz outerLoop2%=\n\t" - - "movl %[run_depth_cells], %%eax\n\t" - "decl %%eax\n\t" - "js finish%=\n\t" - - // Loop for K unrolled by 2 - "outerLoop1%=:\n\t" - - // RHS cell to xmm1 - "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t" - - // LHS cell - "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm4 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm5 \n\t" - - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm6 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm7 \n\t" - - "addl $0x08, %[lhs_ptr]\n\t" - "addl $0x08, %[rhs_ptr]\n\t" - - "decl %[run_depth_cells]\n\t" - "jnz outerLoop1%=\n\t" - - "finish%=:\n\t" - - "movl %[dst_col_stride], %%eax\n\t" - "shll $2, %%eax\n\t" - - "movl %[start_depth], %%ecx\n\t" - "test %%ecx, %%ecx\n\t" - "jz storeDst%=\n\t" - - "leal (%%eax,%%eax,0x2), %%ecx\n\t" - "paddd 0x00(%[dst_ptr]) , %%xmm4 \n\t" - "paddd 0x00(%[dst_ptr], %%eax, 1) , %%xmm5 \n\t" - "paddd 0x00(%[dst_ptr], %%eax, 2) , %%xmm6 \n\t" - "paddd 0x00(%[dst_ptr], %%ecx, 1) , %%xmm7 \n\t" - - "storeDst%=:\n\t" - - "leal (%%eax,%%eax,0x2), %%ecx\n\t" - "movdqu %%xmm4 , 0x00(%[dst_ptr]) \n\t" - "movdqu %%xmm5 , 0x00(%[dst_ptr], %%eax, 1)\n\t" - "movdqu %%xmm6 , 0x00(%[dst_ptr], %%eax, 2)\n\t" - "movdqu %%xmm7 , 0x00(%[dst_ptr], %%ecx, 1)\n\t" - - : // outputs - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_ptr] "+r"(dst_ptr) - : // inputs - [start_depth] "g"(start_depth), [dst_col_stride] "g"(dst_col_stride), - [run_depth_cells] "g"(run_depth_cells) - : // clobbers - "cc", "memory", "%xmm0", "%xmm1", "%xmm3", "%xmm2", "%xmm4", "%xmm5", - "%xmm6", "%xmm7", "%eax", "%ecx"); - } -}; -#endif -#ifdef GEMMLOWP_SSE4_64 -struct SSE4_64_Kernel12x4Depth2 : KernelBase { - typedef KernelFormat< - KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 3>, - KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, 1> > - Format; 
-
-  const char* Name() const override { return "SSE, 12x4, depth 2"; }
-
-  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
-           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
-           const std::uint8_t* rhs_ptr, std::size_t start_depth,
-           std::size_t run_depth) const override {
-    ScopedProfilingLabel label("optimized kernel");
-    assert(dst_row_stride == 1);
-    const std::int64_t run_depth_cells = run_depth / Format::kDepth;
-    const std::int64_t dst_col_stride_q = dst_col_stride;
-
-    /* Main loop */
-
-    // A 2x4 cell of Rhs is stored in 16bit in xmm1 .
-    // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in xmm0, replaced
-    // every Iteration.
-    // A 12x4 block of accumulators is stored in 32bit in xmm4--xmm15.
-    //
-    //                   +-------+-------+-------+-------+
-    //                   |xmm1[0]|xmm1[2]|xmm1[4]|xmm1[6]|
-    //              Rhs  +-------+---------------+-------+
-    //                   |xmm1[1]|xmm1[3]|xmm1[5]|xmm1[7]|
-    //                   +-------+-------+-------+-------+
-    //
-    //                   |       |       |       |       |
-    //
-    //    Lhs            |       |       |       |       |
-    //
-    //  +--+--+ - - - -  +-------+-------+-------+-------+
-    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
-    //  |xmm0 | (Iter1)  | xmm4  | xmm5  | xmm6  | xmm7  |
-    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
-    //  |xmm0 |          | xmm4  | xmm5  | xmm6  | xmm7  |
-    //  +--+--+ - - - -  +-------+-------+-------+-------+
-    //  |xmm0 |          | xmm8  | xmm9  | xmm10 | xmm11 |
-    //  |xmm0 | (Iter2)  | xmm8  | xmm9  | xmm10 | xmm11 |
-    //  |xmm0 |          | xmm8  | xmm9  | xmm10 | xmm11 |
-    //  |xmm0 |          | xmm8  | xmm9  | xmm10 | xmm11 |
-    //  +--+--+ - - - -  +-------+-------+-------+-------+
-    //  |xmm0 |          | xmm12 | xmm13 | xmm14 | xmm15 |
-    //  |xmm0 | (Iter3)  | xmm12 | xmm13 | xmm14 | xmm15 |
-    //  |xmm0 |          | xmm12 | xmm13 | xmm14 | xmm15 |
-    //  |xmm0 |          | xmm12 | xmm13 | xmm14 | xmm15 |
-    //  +--+--+ - - - -  +-------+-------+-------+-------+
-    //
-    //                              Accumulator
-
-    asm volatile(
-
-        // Set registers for destination
-        "movq %[dst_col_stride_q], %%r12\n\t"
-        "shlq $2, %%r12\n\t"
-        "leaq (%%r12,%%r12,0x2), %%r13\n\t"
-
-        // Set accumulators to zero.
- "pxor %%xmm4 , %%xmm4 \n\t" - "pxor %%xmm5 , %%xmm5 \n\t" - "pxor %%xmm6 , %%xmm6 \n\t" - "pxor %%xmm7 , %%xmm7 \n\t" - "pxor %%xmm8 , %%xmm8 \n\t" - "pxor %%xmm9 , %%xmm9 \n\t" - "pxor %%xmm10 , %%xmm10\n\t" - "pxor %%xmm11 , %%xmm11\n\t" - "pxor %%xmm12 , %%xmm12\n\t" - "pxor %%xmm13 , %%xmm13\n\t" - "pxor %%xmm14 , %%xmm14\n\t" - "pxor %%xmm15 , %%xmm15\n\t" - - "movq %[run_depth_cells], %%r14\n\t" - "subq $2, %%r14\n\t" - "js outerLoop1%=\n\t" - - // Loop for K unrolled by 4 - "outerLoop2%=:\n\t" - - // K = 1,2 - // RHS cell to xmm1 - - "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t" - - // LHS cell - "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm4 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm5 \n\t" - - "prefetcht0 0x80(%[lhs_ptr]) \n\t" - - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm6 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm7 \n\t" - - // next LHS cell - "pmovzxbw 0x08(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm8 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm9 \n\t" - - "prefetcht0 0x80(%[rhs_ptr]) \n\t" - - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm10 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm11 \n\t" - - // next LHS cell - "pmovzxbw 0x10(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm12 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm13 \n\t" - - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm14 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm15 \n\t" - - // K = 3,4 - // RHS cell to xmm1 - "pmovzxbw 0x08(%[rhs_ptr]), %%xmm1\n\t" - - // LHS cell - "pmovzxbw 0x18(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm4 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm5 \n\t" - - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm6 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm7 \n\t" - - // next LHS cell - "pmovzxbw 0x20(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm8 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm9 \n\t" - - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm10 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm11 \n\t" - - // next LHS cell - "pmovzxbw 0x28(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm12 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm13 \n\t" - - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm14 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm15 \n\t" - - "addq $0x30, %[lhs_ptr]\n\t" - "addq $0x10, %[rhs_ptr]\n\t" - - 
"subq $2, %[run_depth_cells]\n\t" - "jnz outerLoop2%=\n\t" - - "movq %[run_depth_cells], %%r14\n\t" - "decq %%r14\n\t" - "js finish%=\n\t" - - // Loop for K unrolled by 2 - "outerLoop1%=:\n\t" - - // RHS cell to xmm1 - "pmovzxbw (%[rhs_ptr]), %%xmm1\n\t" - - // LHS cell - "pmovzxbw 0x00(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm4 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm5 \n\t" - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm6 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm7 \n\t" - - // next LHS cell - "pmovzxbw 0x08(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm8 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm9 \n\t" - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm10 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm11 \n\t" - - // next LHS cell - "pmovzxbw 0x10(%[lhs_ptr]), %%xmm0\n\t" - "pshufd $0x00,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm12 \n\t" - "pshufd $0x55,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm13 \n\t" - "pshufd $0xaa,%%xmm1,%%xmm2 \n\t" - "pmaddwd %%xmm0, %%xmm2 \n\t" - "paddd %%xmm2, %%xmm14 \n\t" - "pshufd $0xff,%%xmm1,%%xmm3 \n\t" - "pmaddwd %%xmm0, %%xmm3 \n\t" - "paddd %%xmm3, %%xmm15 \n\t" - - "addq $0x18, %[lhs_ptr]\n\t" - "addq $0x08, %[rhs_ptr]\n\t" - - "decq %[run_depth_cells]\n\t" - "jnz outerLoop1%=\n\t" - - "finish%=:\n\t" - - "test %[start_depth], %[start_depth]\n\t" - "jz storeDst%=\n\t" - - "paddd 0x00(%[dst_ptr]) , %%xmm4 \n\t" - "paddd 0x10(%[dst_ptr]) , %%xmm8 \n\t" - "paddd 0x20(%[dst_ptr]) , %%xmm12\n\t" - "paddd 0x00(%[dst_ptr], %%r12, 1) , %%xmm5 \n\t" - "paddd 0x10(%[dst_ptr], %%r12, 1) , %%xmm9 \n\t" - "paddd 0x20(%[dst_ptr], %%r12, 1) , %%xmm13\n\t" - "paddd 0x00(%[dst_ptr], %%r12, 2) , %%xmm6 \n\t" - "paddd 0x10(%[dst_ptr], %%r12, 2) , %%xmm10\n\t" - "paddd 0x20(%[dst_ptr], %%r12, 2) , %%xmm14\n\t" - "paddd 0x00(%[dst_ptr], %%r13, 1) , %%xmm7 \n\t" - "paddd 0x10(%[dst_ptr], %%r13, 1) , %%xmm11\n\t" - "paddd 0x20(%[dst_ptr], %%r13, 1) , %%xmm15\n\t" - - "storeDst%=:\n\t" - - "movdqu %%xmm4 , 0x00(%[dst_ptr]) \n\t" - "movdqu %%xmm8 , 0x10(%[dst_ptr]) \n\t" - "movdqu %%xmm12 , 0x20(%[dst_ptr]) \n\t" - "movdqu %%xmm5 , 0x00(%[dst_ptr], %%r12, 1)\n\t" - "movdqu %%xmm9 , 0x10(%[dst_ptr], %%r12, 1)\n\t" - "movdqu %%xmm13 , 0x20(%[dst_ptr], %%r12, 1)\n\t" - "movdqu %%xmm6 , 0x00(%[dst_ptr], %%r12, 2)\n\t" - "movdqu %%xmm10 , 0x10(%[dst_ptr], %%r12, 2)\n\t" - "movdqu %%xmm14 , 0x20(%[dst_ptr], %%r12, 2)\n\t" - "movdqu %%xmm7 , 0x00(%[dst_ptr], %%r13, 1)\n\t" - "movdqu %%xmm11 , 0x10(%[dst_ptr], %%r13, 1)\n\t" - "movdqu %%xmm15 , 0x20(%[dst_ptr], %%r13, 1)\n\t" - - : // outputs - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_ptr] "+r"(dst_ptr) - : // inputs - [start_depth] "r"(start_depth), - [dst_col_stride_q] "r"(dst_col_stride_q), - [run_depth_cells] "r"(run_depth_cells) - : // clobbers - "cc", "memory", "%xmm0", "%xmm1", "%xmm3", "%xmm2", "%xmm4", "%xmm5", - "%xmm6", "%xmm7", "%xmm8", "%xmm9", "%xmm10", "%r12", "%r13", "%r14", - "%xmm11", "%xmm12", "%xmm13", "%xmm14", "%xmm15"); - } -}; -#endif - -} // namespace gemmlowp - -#endif // 
diff --git a/internal/pack_SSE.h b/internal/pack_SSE.h
deleted file mode 100644
index aef4683..0000000
--- a/internal/pack_SSE.h
+++ /dev/null
@@ -1,169 +0,0 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// pack_SSE.h: optimized SSE specializations of the templates in pack.h.
-
-#ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_
-#define GEMMLOWP_INTERNAL_PACK_SSE_H_
-
-#include <smmintrin.h>
-#include "pack.h"
-
-namespace gemmlowp {
-
-// Requantizes source values pointed to by raw_src_ptr in the [0..255] range
-// to the range specified by BitDepth, [0..((2^bits)-1)].
-// This is in-place requantization, where the input is
-// not modified if 8bit integers are used. SSE does not
-// have less-than-8bit kernels currently. Although SSE registers
-// hold 16 uint8_t elements, only the first 8 uint8_t elements are
-// requantized. The packing only uses the first 8 uint8_t elements
-// of the SSE registers. Therefore, requantizing all 16 uint8_t
-// elements would be wasteful computation.
-template <typename QuantizationParams>
-void SSERequantize(
-    __m128i* raw_src_ptr,
-    ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>*
-        rounding_offset_generator) {
-  static const int kBits = QuantizationParams::BitDepth::kBits;
-  static const std::uint8_t kMaxVal = (1 << kBits) - 1;
-  if (kBits == 8) {
-    return;
-  }
-
-  std::uint8_t* raw_src_ui8_ptr = (std::uint8_t*)&raw_src_ptr[0];
-
-  // modify only first 8 elements in the register (see note above)
-  for (int i = 0; i < 8; ++i) {
-    std::uint16_t scaled =
-        static_cast<std::uint16_t>(raw_src_ui8_ptr[i]) * kMaxVal;
-    std::uint8_t rounding_offset = rounding_offset_generator->get();
-    raw_src_ui8_ptr[i] = (scaled + rounding_offset) / 255;
-  }
-}
-
-// TODO: Add DepthMajorUint8SideMap
-
-typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
-    WidthMajorUint8SideMap;
-
-template <int Cells>
-using WidthMajorSideFormatNCells4x2 =
-    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
-
-template <typename QuantizationParams, int Cells>
-class PackingRegisterBlock<
-    QuantizationParams, WidthMajorUint8SideMap,
-    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > >
-    : public PackingRegisterBlockBase<
-          QuantizationParams, WidthMajorUint8SideMap,
-          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > {
- public:
-  typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
-  typedef typename KernelSideFormat::Cell CellFormat;
-  static const int kCells = KernelSideFormat::kCells;
-  static const int kCellWidth = CellFormat::kWidth;
-  static const int kKernelWidth = CellFormat::kWidth * kCells;
-  static const int kCellDepth = CellFormat::kDepth;
-  static const int kCellSize = CellFormat::kSize;
-
-  typedef ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>
-      RoundingOffsetGenerator;
-
-  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width,
-            RoundingOffsetGenerator* rounding_offset_generator) {
-    std::uint8_t* dst_ptr = dst->current_data();
-    const int width_stride = this->complete_src_.width_stride();
-    int depth_step = 8;
-
-    __m128i one = _mm_set1_epi16(1);
-    for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
-         cell_start_depth += depth_step) {
-      for (int cell_start_width = 0; cell_start_width < kKernelWidth;
-           cell_start_width += kCellWidth) {
-        std::int32_t* cell_sums_of_each_slice_ptr =
-            dst->sums_of_each_slice() + start_width + cell_start_width;
-        const std::uint8_t* src_data =
-            this->complete_src_.data(cell_start_width, cell_start_depth);
-
-        __m128i xmm1 =
-            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0]));
-        __m128i xmm2 = _mm_loadl_epi64(
-            reinterpret_cast<const __m128i*>(&src_data[1 * width_stride]));
-        __m128i xmm3 = _mm_loadl_epi64(
-            reinterpret_cast<const __m128i*>(&src_data[2 * width_stride]));
-        __m128i xmm4 = _mm_loadl_epi64(
-            reinterpret_cast<const __m128i*>(&src_data[3 * width_stride]));
-
-        __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
-        __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
-
-        __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
-        __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
-
-        __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
-        SSERequantize<QuantizationParams>(&xmm9, rounding_offset_generator);
-
-        __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
-        SSERequantize<QuantizationParams>(&xmm10, rounding_offset_generator);
-
-        _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9);
-        _mm_storel_epi64(
-            reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10);
-
-        __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
-        SSERequantize<QuantizationParams>(&xmm11, rounding_offset_generator);
-
-        __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
-        SSERequantize<QuantizationParams>(&xmm12, rounding_offset_generator);
-
-        _mm_storel_epi64(
-            reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]),
-            xmm11);
-        _mm_storel_epi64(
-            reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]),
-            xmm12);
-
-        xmm1 = _mm_cvtepu8_epi16(xmm9);
-        xmm2 = _mm_madd_epi16(xmm1, one);
-        __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
-            reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0]));
-        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
-
-        xmm1 = _mm_cvtepu8_epi16(xmm10);
-        xmm2 = _mm_madd_epi16(xmm1, one);
-        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
-
-        xmm1 = _mm_cvtepu8_epi16(xmm11);
-        xmm2 = _mm_madd_epi16(xmm1, one);
-        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
-
-        xmm1 = _mm_cvtepu8_epi16(xmm12);
-        xmm2 = _mm_madd_epi16(xmm1, one);
-        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
-
-        _mm_storeu_si128(
-            reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]),
-            sums_of_each_slice_xmm);
-        dst_ptr += kCellSize;
-      }
-      dst_ptr += 3 * kCellSize * kCells;
-    }
-    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
-  }
-};
-
-}  // namespace gemmlowp
-
-#endif  // GEMMLOWP_INTERNAL_PACK_SSE_H_
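[Reviewer note, not part of the change: SSERequantize above rescales a value in [0..255] to [0..(2^bits)-1] with a rounding offset. The arithmetic in scalar form, as a sketch:

    #include <cstdint>
    #include <cstdio>

    // Scalar form of the requantization in SSERequantize: rescale v in
    // [0..255] to [0..max_val], with a rounding offset in [0..254].
    uint8_t Requantize(uint8_t v, int bits, uint8_t rounding_offset) {
      const uint16_t max_val = (1 << bits) - 1;
      const uint16_t scaled = static_cast<uint16_t>(v) * max_val;
      return static_cast<uint8_t>((scaled + rounding_offset) / 255);
    }

    int main() {
      // 255 always maps to the new maximum; here 5 bits -> [0..31].
      printf("%d\n", Requantize(255, 5, 127));  // prints 31
      printf("%d\n", Requantize(128, 5, 127));  // ~ 128 * 31 / 255 -> 16
    }
]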
diff --git a/internal/unpack_neon.h b/internal/unpack_neon.h
deleted file mode 100644
index 5c9e76a..0000000
--- a/internal/unpack_neon.h
+++ /dev/null
@@ -1,212 +0,0 @@
-// Copyright 2015 Google Inc. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// unpack_neon.h: optimized NEON specializations of the templates in unpack.h.
-
-#ifndef GEMMLOWP_INTERNAL_UNPACK_NEON_H_
-#define GEMMLOWP_INTERNAL_UNPACK_NEON_H_
-
-#include "output_neon.h"
-#include "unpack.h"
-
-#include <arm_neon.h>
-
-namespace gemmlowp {
-
-template <std::uint32_t numerator, std::uint32_t denominator>
-int32x4_t RoundingMultiplyByConstantFraction(int32x4_t x) {
-  static_assert(numerator > 0 && denominator > 0,
-                "only supporting positive num/denom");
-
-  if (numerator == denominator) {
-    return x;
-  }
-
-  static const std::int32_t int_quotient =
-      (numerator + denominator / 2) / denominator;
-  static const std::int32_t remaining_numerator =
-      numerator - int_quotient * denominator;
-  static const std::int32_t scaled_remaining_numerator =
-      static_cast<std::int32_t>(
-          (static_cast<std::int64_t>(remaining_numerator) * (1ll << 31)) /
-          denominator);
-  // Note: vqrdmulh instruction is rounding doubling multiply high.
-  const int32x4_t remaining_product =
-      vqrdmulhq_n_s32(x, scaled_remaining_numerator);
-
-  return vmlaq_n_s32(remaining_product, x, int_quotient);
-}
-
-template <typename tScalar, VectorShape tShape>
-int32x4_t get_int32x4_t_and_inc(
-    ConstIterator<VectorMap<tScalar, tShape>>* iterator) {
-  const int32x4_t result = vld1q_s32(iterator->get());
-  *iterator += 4;
-  return result;
-}
-
-template <typename tScalar, VectorShape tShape>
-int32x4_t get_int32x4_t_and_inc(
-    ConstIterator<VectorDup<tScalar, tShape>>* iterator) {
-  const int32x4_t result = vdupq_n_s32(**iterator);
-  // Increment really does nothing for VectorDup.
-  *iterator += 4;
-  return result;
-}
-
-template <typename BitDepthParams, typename PackedResultType,
-          typename OutputScalar, typename LhsOffset, typename RhsOffset,
-          typename OutputPipelineType>
-struct UnpackResultImpl<BitDepthParams,
-                        MatrixMap<OutputScalar, MapOrder::ColMajor>,
-                        PackedResultType, LhsOffset, RhsOffset,
-                        OutputPipelineType> {
-  typedef MatrixMap<OutputScalar, MapOrder::ColMajor> ResultBlockType;
-  static void Unpack(ResultBlockType* dst, const MatrixBlockBounds& dst_block,
-                     const PackedResultType& src, int depth,
-                     const std::int32_t* lhs_sums_of_each_slice,
-                     const std::int32_t* rhs_sums_of_each_slice,
-                     const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
-                     const OutputPipelineType& output_pipeline) {
-    ScopedProfilingLabel label("optimized path (NEON)");
-    assert(dst_block.start_row >= 0);
-    assert(dst_block.start_row + dst_block.rows <= dst->rows());
-    assert(dst_block.start_col >= 0);
-    assert(dst_block.start_col + dst_block.cols <= dst->cols());
-    const int kLhsBits = BitDepthParams::LhsBitDepth::kBits;
-    const int kRhsBits = BitDepthParams::RhsBitDepth::kBits;
-    const std::int32_t kLhsMax = (1 << kLhsBits) - 1;
-    const std::int32_t kRhsMax = (1 << kRhsBits) - 1;
-    auto src_map = src.Map();
-    OutputPipelineExecutor<OutputPipelineType, FragmentInt32x1x1>
-        output_pipeline_executor_int32x1x1(output_pipeline);
-    OutputPipelineExecutor<OutputPipelineType, NEONFragmentInt32x4x1>
-        output_pipeline_executor_int32x4x1(output_pipeline);
-    OutputPipelineExecutor<OutputPipelineType, NEONFragmentInt32x16x1>
-        output_pipeline_executor_int32x16x1(output_pipeline);
-
-    for (int c = 0; c < dst_block.cols; c++) {
-      int c_dst = c + dst_block.start_col;
-      const std::int32_t* src_ptr = src_map.data(0, c);
-      const std::int32_t* sums_of_each_slice_ptr = lhs_sums_of_each_slice;
-      auto lhs_offset_iter = const_iterator(lhs_offset, dst_block.start_row);
-      const std::int32_t rhs_offset_c = rhs_offset(c_dst);
-      const std::int32_t rhs_sums_of_each_slice_c = rhs_sums_of_each_slice[c];
-
-      // Handle 16 values at once for higher performance
-      int dst_rows_aligned16 = RoundDown<16>(dst_block.rows);
-      for (int r = 0; r < dst_rows_aligned16; r += 16) {
-        int r_dst = r + dst_block.start_row;
-        // Compute the sum of the 4 terms,
-        //   q = term_xx + term_x1 + term_1x_plus_term_11
-        // Refer to the generic code in unpack.h.
-        int32x4_t raw_xx[4];
-        for (int i = 0; i < 4; i++) {
-          raw_xx[i] = vld1q_s32(src_ptr);
-          src_ptr += 4;
-        }
-        int32x4_t raw_x1[4];
-        for (int i = 0; i < 4; i++) {
-          const int32x4_t sum_x1 = vld1q_s32(sums_of_each_slice_ptr);
-          raw_x1[i] = vmulq_n_s32(sum_x1, rhs_offset_c);
-          sums_of_each_slice_ptr += 4;
-        }
-        int32x4_t raw_1x[4];
-        int32x4_t term_11[4];
-        for (int i = 0; i < 4; i++) {
-          const int32x4_t lhs_offsets = get_int32x4_t_and_inc(&lhs_offset_iter);
-          raw_1x[i] = vmulq_n_s32(lhs_offsets, rhs_sums_of_each_slice_c);
-          term_11[i] = vmulq_n_s32(lhs_offsets, rhs_offset_c * depth);
-        }
-        int32x4_t term_xx[4];
-        for (int i = 0; i < 4; i++) {
-          term_xx[i] =
-              RoundingMultiplyByConstantFraction<255 * 255, kLhsMax * kRhsMax>(
-                  raw_xx[i]);
-        }
-        int32x4_t term_x1[4];
-        for (int i = 0; i < 4; i++) {
-          term_x1[i] =
-              RoundingMultiplyByConstantFraction<255, kLhsMax>(raw_x1[i]);
-        }
-        int32x4_t term_1x[4];
-        for (int i = 0; i < 4; i++) {
-          term_1x[i] =
-              RoundingMultiplyByConstantFraction<255, kRhsMax>(raw_1x[i]);
-        }
-        int32x4x4_t q;
-        for (int i = 0; i < 4; i++) {
-          q.val[i] = vaddq_s32(vaddq_s32(term_xx[i], term_x1[i]),
-                               vaddq_s32(term_1x[i], term_11[i]));
-        }
-        NEONFragmentInt32x16x1 f(q);
-        output_pipeline_executor_int32x16x1.Execute(f, dst, r_dst, c_dst);
-      }
-      // We have finished handling groups of 16 entries at once; now
-      // try to handle 4 entries at once.
-      int dst_rows_aligned4 = RoundDown<4>(dst_block.rows);
-      for (int r = dst_rows_aligned16; r < dst_rows_aligned4; r += 4) {
-        int r_dst = r + dst_block.start_row;
-        // Compute the sum of the 4 terms,
-        //   q = term_xx + term_x1 + term_1x_plus_term_11
-        // Refer to the generic code in unpack.h.
-        const int32x4_t raw_xx = vld1q_s32(src_ptr);
-        src_ptr += 4;
-        const int32x4_t term_xx =
-            RoundingMultiplyByConstantFraction<255 * 255, kLhsMax * kRhsMax>(
-                raw_xx);
-        const int32x4_t sum_x1 = vld1q_s32(sums_of_each_slice_ptr);
-        const int32x4_t raw_x1 = vmulq_n_s32(sum_x1, rhs_offset_c);
-        sums_of_each_slice_ptr += 4;
-        const int32x4_t term_x1 =
-            RoundingMultiplyByConstantFraction<255, kLhsMax>(raw_x1);
-        const int32x4_t lhs_offsets = get_int32x4_t_and_inc(&lhs_offset_iter);
-        const int32x4_t raw_1x =
-            vmulq_n_s32(lhs_offsets, rhs_sums_of_each_slice_c);
-        const int32x4_t term_1x =
-            RoundingMultiplyByConstantFraction<255, kRhsMax>(raw_1x);
-        const int32x4_t term_11 =
-            vmulq_n_s32(lhs_offsets, rhs_offset_c * depth);
-        int32x4_t q = vaddq_s32(vaddq_s32(term_xx, term_x1),
-                                vaddq_s32(term_1x, term_11));
-        NEONFragmentInt32x4x1 f(q);
-        output_pipeline_executor_int32x4x1.Execute(f, dst, r_dst, c_dst);
-      }
-      // We have finished handling 4 entries at once; now handle
-      // remaining entries one by one. This scalar code is similar
-      // to the code in unpack.h, see comments there.
-      for (int r = dst_rows_aligned4; r < dst_block.rows; r++) {
-        int r_dst = r + dst_block.start_row;
-        const std::int32_t raw_xx = src_map(r, c);
-        const std::int32_t raw_x1 = lhs_sums_of_each_slice[r] * rhs_offset_c;
-        const std::int32_t raw_1x =
-            rhs_sums_of_each_slice_c * lhs_offset(r_dst);
-        const std::int32_t term_xx =
-            RoundingMultiplyByConstantFraction<255 * 255, kLhsMax * kRhsMax>(
-                raw_xx);
-        const std::int32_t term_x1 =
-            RoundingMultiplyByConstantFraction<255, kLhsMax>(raw_x1);
-        const std::int32_t term_1x =
-            RoundingMultiplyByConstantFraction<255, kRhsMax>(raw_1x);
-        const std::int32_t term_11 = lhs_offset(r) * rhs_offset(c) * depth;
-        FragmentInt32x1x1 sum = term_xx + term_x1 + term_1x + term_11;
-        output_pipeline_executor_int32x1x1.Execute(sum, dst, r_dst, c_dst);
-      }
-    }
-  }
-};
-
-}  // namespace gemmlowp
-
-#endif  // GEMMLOWP_INTERNAL_UNPACK_NEON_H_
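[Reviewer note, not part of the change: RoundingMultiplyByConstantFraction above splits numerator/denominator into an integer quotient plus a remainder expressed as a Q0.31 multiplier, then combines x*quotient with a rounded multiply-high (vqrdmulh) by that remainder. A scalar sketch of the same computation, illustration only:

    #include <cstdint>
    #include <cstdio>

    // Scalar model of RoundingMultiplyByConstantFraction<num, den>.
    int32_t MultiplyByFraction(int32_t x, int32_t num, int32_t den) {
      const int32_t int_quotient = (num + den / 2) / den;
      const int32_t remaining_num = num - int_quotient * den;
      // Remainder as a Q0.31 multiplier, as in the NEON version.
      const int32_t scaled_remaining_num = static_cast<int32_t>(
          (static_cast<int64_t>(remaining_num) * (1ll << 31)) / den);
      // Rounding multiply-high, mirroring vqrdmulh semantics.
      const int64_t prod = static_cast<int64_t>(x) * scaled_remaining_num;
      const int64_t nudge = prod >= 0 ? (1ll << 30) : 1 - (1ll << 30);
      const int32_t remaining_product =
          static_cast<int32_t>((prod + nudge) / (1ll << 31));
      return x * int_quotient + remaining_product;
    }

    int main() {
      // E.g. rescaling by 255/127, as the 7-bit depth case above would:
      // 1000 * 255 / 127 = 2007.87..., so expect about 2008.
      printf("%d\n", MultiplyByFraction(1000, 255, 127));
    }
]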