diff options
Diffstat (limited to 'fixedpoint')
-rw-r--r-- | fixedpoint/fixedpoint.h | 68 | ||||
-rw-r--r-- | fixedpoint/fixedpoint_avx.h | 218 | ||||
-rw-r--r-- | fixedpoint/fixedpoint_msa.h | 75 | ||||
-rw-r--r-- | fixedpoint/fixedpoint_neon.h | 30 | ||||
-rw-r--r-- | fixedpoint/fixedpoint_sse.h | 4 |
5 files changed, 361 insertions, 34 deletions
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h index d39341b..58e8050 100644 --- a/fixedpoint/fixedpoint.h +++ b/fixedpoint/fixedpoint.h @@ -18,10 +18,13 @@ #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ +#include <algorithm> #include <cassert> +#include <cmath> +#include <cstdint> #include <limits> -#include "../internal/common.h" +#include "../internal/detect_platform.h" namespace gemmlowp { @@ -47,13 +50,13 @@ struct FixedPointRawTypeTraits {}; template <> struct FixedPointRawTypeTraits<std::int32_t> { typedef std::int32_t ScalarRawType; - static const int kLanes = 1; + static constexpr int kLanes = 1; }; template <> struct FixedPointRawTypeTraits<std::int16_t> { typedef std::int16_t ScalarRawType; - static const int kLanes = 1; + static constexpr int kLanes = 1; }; // Returns a SIMD value duplicating a scalar value across all lanes. @@ -109,11 +112,25 @@ tIntegerType Neg(tIntegerType a) { return -a; } -// Integer arithmetic left-shift, equivalent to multiplying with a -// power of two. Not saturating. Overflow is undefined behavior. -template <typename tIntegerType> -tIntegerType ShiftLeft(tIntegerType a, int offset) { - return a << offset; +// Integer arithmetic left-shift, equivalent to multiplying with a power of two. +// Negative values are OK. In case of overflow, no Undefined +// Behavior, but the results are implementation-defined (in practice, +// they currently are saturated, but we make no commitment to that). The idea +// is that the caller will want to implement the overflowing cases with +// saturation with compare-and-mask, so we don't care about the results +// in the overflow case, we just want to avoid undefined behavior. +// +// tIntegerType may be int32 or any narrower signed type. +template <typename tIntegerType, typename OffsetType> +tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) { + const std::int64_t wide_a = static_cast<std::int64_t>(a); + const std::int64_t wide_shifted = wide_a * (1 << offset); + const auto min = std::numeric_limits<tIntegerType>::min(); + const auto max = std::numeric_limits<tIntegerType>::max(); + return wide_shifted < min + ? min + : wide_shifted > max ? max + : static_cast<tIntegerType>(wide_shifted); } // Integer arithmetic right-shift. Not rounding. @@ -137,7 +154,7 @@ tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, // input scalar is non-zero. template <typename tIntegerType> tIntegerType MaskIfNonZero(tIntegerType a) { - static const tIntegerType zero = 0; + static constexpr tIntegerType zero = 0; return a ? BitNot(zero) : zero; } @@ -211,6 +228,7 @@ bool Any(tIntegerType a) { template <typename IntegerType> IntegerType RoundingHalfSum(IntegerType a, IntegerType b) { static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); + (void)b; return a; } @@ -235,6 +253,7 @@ inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) { template <typename IntegerType> IntegerType SaturatingAdd(IntegerType a, IntegerType b) { static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); + (void)b; return a; } @@ -244,7 +263,9 @@ inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) { std::int32_t a32 = a; std::int32_t b32 = b; std::int32_t sum = a32 + b32; - return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum))); + return static_cast<std::int16_t>( + std::min(static_cast<std::int32_t>(32767), + std::max(static_cast<std::int32_t>(-32768), sum))); } // Returns a+b, saturating if the integers are 16bit or narrower, @@ -298,6 +319,7 @@ IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { template <typename IntegerType> IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); + (void)b; return a; } @@ -331,8 +353,8 @@ inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, // Correctly-rounded-to-nearest division by a power-of-two. // Also known as a rounding arithmetic right shift. -template <typename IntegerType> -inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) { +template <typename IntegerType, typename ExponentType> +inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) { assert(exponent >= 0); assert(exponent <= 31); const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); @@ -432,9 +454,9 @@ class FixedPoint { typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; typedef typename RawTypeTraits::ScalarRawType ScalarRawType; - static const int kTotalBits = 8 * sizeof(ScalarRawType); - static const int kIntegerBits = tIntegerBits; - static const int kFractionalBits = kTotalBits - 1 - kIntegerBits; + static constexpr int kTotalBits = 8 * sizeof(ScalarRawType); + static constexpr int kIntegerBits = tIntegerBits; + static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits"); @@ -474,7 +496,7 @@ class FixedPoint { template <int Exponent> static FixedPoint ConstantPOT() { - static const int kOffset = kFractionalBits + Exponent; + static constexpr int kOffset = kFractionalBits + Exponent; static_assert( kOffset < 31, "Constant not exactly representable in this fixed-point format"); @@ -645,7 +667,7 @@ double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> FixedPoint<tRawType, tIntegerBitsDst> Rescale( FixedPoint<tRawType, tIntegerBitsSrc> x) { - static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst; + static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; FixedPoint<tRawType, tIntegerBitsDst> result; result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); return result; @@ -725,9 +747,9 @@ FixedPoint<tRawType, 0> exp_on_negative_values( FixedPoint<tRawType, tIntegerBits> a) { typedef FixedPoint<tRawType, tIntegerBits> InputF; typedef FixedPoint<tRawType, 0> ResultF; - static const int kFractionalBits = InputF::kFractionalBits; - static const int kIntegerBits = InputF::kIntegerBits; - static const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); + static constexpr int kFractionalBits = InputF::kFractionalBits; + static constexpr int kIntegerBits = InputF::kIntegerBits; + const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); InputF mask = kOneQuarter - InputF::FromScalarRaw(1); InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( @@ -755,10 +777,10 @@ FixedPoint<tRawType, 0> exp_on_negative_values( #undef GEMMLOWP_EXP_BARREL_SHIFTER + static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0; if (kIntegerBits > 5) { - static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0; const InputF clamp = - GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0); + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0); result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); } @@ -867,6 +889,8 @@ FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) { #ifdef GEMMLOWP_NEON #include "./fixedpoint_neon.h" +#elif defined(GEMMLOWP_AVX2) +#include "./fixedpoint_avx.h" #elif defined(GEMMLOWP_SSE4) #include "./fixedpoint_sse.h" #elif defined(GEMMLOWP_MSA) diff --git a/fixedpoint/fixedpoint_avx.h b/fixedpoint/fixedpoint_avx.h new file mode 100644 index 0000000..1816386 --- /dev/null +++ b/fixedpoint/fixedpoint_avx.h @@ -0,0 +1,218 @@ +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// fixedpoint_avx.h: optimized avx specializations of the templates +// in fixedpoint.h. + +#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ +#define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ + +#include <smmintrin.h> +#include "fixedpoint.h" +#include "fixedpoint_sse.h" + +namespace gemmlowp { + +template <> +struct FixedPointRawTypeTraits<__m256i> { + typedef std::int32_t ScalarRawType; + static const int kLanes = 4; +}; + +template <> +inline __m256i BitAnd(__m256i a, __m256i b) { + return _mm256_and_si256(a, b); +} + +template <> +inline __m256i BitOr(__m256i a, __m256i b) { + return _mm256_or_si256(a, b); +} + +template <> +inline __m256i BitXor(__m256i a, __m256i b) { + return _mm256_xor_si256(a, b); +} + +template <> +inline __m256i BitNot(__m256i a) { + return _mm256_andnot_si256(a, _mm256_set1_epi32(-1)); +} + +template <> +inline __m256i Add(__m256i a, __m256i b) { + return _mm256_add_epi32(a, b); +} + +template <> +inline __m256i Mul(__m256i a, __m256i b) { + return _mm256_mullo_epi32(a, b); +} + +template <> +inline __m256i Sub(__m256i a, __m256i b) { + return _mm256_sub_epi32(a, b); +} + +template <> +inline __m256i Neg(__m256i a) { + return _mm256_sign_epi32(a, _mm256_set1_epi32(-1)); +} + +template <> +inline __m256i ShiftLeft(__m256i a, int offset) { + return _mm256_slli_epi32(a, offset); +} + +template <> +inline __m256i ShiftRight(__m256i a, int offset) { + return _mm256_srai_epi32(a, offset); +} + +template <> +inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val, + __m256i else_val) { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(else_val), + _mm256_castsi256_ps(then_val), + _mm256_castsi256_ps(if_mask))); +} + +template <> +inline __m256i MaskIfEqual(__m256i a, __m256i b) { + return _mm256_cmpeq_epi32(a, b); +} + +template <> +inline __m256i MaskIfNotEqual(__m256i a, __m256i b) { + return BitNot(MaskIfEqual(a, b)); +} + +template <> +inline __m256i MaskIfZero(__m256i a) { + return MaskIfEqual(a, _mm256_set1_epi32(0)); +} + +template <> +inline __m256i MaskIfNonZero(__m256i a) { + return MaskIfNotEqual(a, _mm256_set1_epi32(0)); +} + +template <> +inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) { + return _mm256_cmpgt_epi32(a, b); +} + +template <> +inline __m256i MaskIfLessThan(__m256i a, __m256i b) { + return _mm256_cmpgt_epi32(b, a); +} + +template <> +inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) { + return BitNot(MaskIfLessThan(a, b)); +} + +template <> +inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) { + return BitNot(MaskIfGreaterThan(a, b)); +} + +/* Assumptions: + - All and Any are used on masks. + - masks are all_ones for true lanes, all_zeroes otherwise. +Hence, All means all 128bits set, and Any means any bit set. +*/ + +template <> +inline bool All(__m256i a) { + return _mm256_testc_si256(a, a); +} + +template <> +inline bool Any(__m256i a) { + return BitNot(_mm256_testz_si256(a, a)); +} + +template <> +inline __m256i RoundingHalfSum(__m256i a, __m256i b) { + /* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */ + /* We divide the inputs before the add to avoid the overflow and costly test + */ + /* of checking if an overflow occured on signed add */ + /* round_bit_mask = _mm_set1_epi32(1); */ + /* a_over_2 = _mm_srai_epi32(a, 1); */ + /* b_over_2 = _mm_srai_epi32(b, 1); */ + /* sum = Add(a_over_2, b_over_2); */ + /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a,b), round_bit_mask), sum); */ + /* return Add(sum, round_bit); */ + + /* Other possibility detecting overflow and xor the sign if an overflow + * happened*/ + __m256i one, sign_bit_mask, sum, rounded_half_sum, overflow, result; + one = _mm256_set1_epi32(1); + sign_bit_mask = _mm256_set1_epi32(0x80000000); + sum = Add(a, b); + rounded_half_sum = _mm256_srai_epi32(Add(sum, one), 1); + overflow = + BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)), + sign_bit_mask); + result = BitXor(rounded_half_sum, overflow); + return result; +} + +template <> +inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) { + __m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3; + __m256i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded; + __m256i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result; + __m256i nudge; + + // saturation only happen if a == b == INT_MIN + min = _mm256_set1_epi32(std::numeric_limits<std::int32_t>::min()); + saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min)); + + // a = a0 | a1 | a2 | a3 + // b = b0 | b1 | b2 | b3 + a0_a2 = a; + a1_a3 = _mm256_srli_si256(a, 4); + b0_b2 = b; + b1_b3 = _mm256_srli_si256(b, 4); + + a0b0_a2b2 = _mm256_mul_epi32(a0_a2, b0_b2); + a1b1_a3b3 = _mm256_mul_epi32(a1_a3, b1_b3); + + // do the rounding and take into account that it will be doubled + nudge = _mm256_set1_epi64x(1 << 30); + a0b0_a2b2_rounded = _mm256_add_epi64(a0b0_a2b2, nudge); + a1b1_a3b3_rounded = _mm256_add_epi64(a1b1_a3b3, nudge); + + // do the doubling + a0b0_a2b2_rounded_2x = _mm256_slli_epi64(a0b0_a2b2_rounded, 1); + a1b1_a3b3_rounded_2x = _mm256_slli_epi64(a1b1_a3b3_rounded, 1); + + // get the high part of the products + result = _mm256_blend_epi16(_mm256_srli_si256(a0b0_a2b2_rounded_2x, 4), + a1b1_a3b3_rounded_2x, 0xcc); + + // saturate those which overflowed + return SelectUsingMask(saturation_mask, min, result); +} + +template <> +inline __m256i Dup<__m256i>(std::int32_t x) { + return _mm256_set1_epi32(x); +} + +} // end namespace gemmlowp + +#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_ diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h index c7a110c..b17f32a 100644 --- a/fixedpoint/fixedpoint_msa.h +++ b/fixedpoint/fixedpoint_msa.h @@ -25,13 +25,13 @@ namespace gemmlowp { template <> struct FixedPointRawTypeTraits<v4i32> { typedef std::int32_t ScalarRawType; - static const int kLanes = 4; + static constexpr int kLanes = 4; }; template <> struct FixedPointRawTypeTraits<v8i16> { typedef std::int16_t ScalarRawType; - static const int kLanes = 8; + static constexpr int kLanes = 8; }; template <> @@ -326,11 +326,71 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> { } }; -// TODO: possibly implement: -// template <> v4i32 RoundingDivideByPOT(v4i32, int) -// template <> v8i16 RoundingDivideByPOT(v8i16, int) -// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> -// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> +template <int Exponent> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> { + static v4i32 eval(v4i32 x) { + static_assert(-31 <= Exponent && Exponent <= -1, ""); + // Isolate the sign bits. + v4i32 sign = __builtin_msa_srli_w(x, 31); + // Decrement the negative elements by 1 (with saturation). + x = __builtin_msa_subs_s_w(x, sign); + // Arithmetic shift right with rounding. + // The srari instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srari_w(x, -Exponent); + } +}; + +template <int Exponent> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> { + static v8i16 eval(v8i16 x) { + static_assert(-15 <= Exponent && Exponent <= -1, ""); + // Isolate the sign bits. + v8i16 sign = __builtin_msa_srli_h(x, 15); + // Decrement the negative elements by 1 (with saturation). + x = __builtin_msa_subs_s_h(x, sign); + // Arithmetic shift right with rounding. + // The srari instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srari_h(x, -Exponent); + } +}; + +template <> +inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) { + v4i32 e = __builtin_msa_fill_w(exponent); + // Isolate the sign bits. + v4i32 sign = __builtin_msa_srli_w(x, 31); + // Reset them to 0 if exponent is 0. + sign = __builtin_msa_min_s_w(sign, e); + // Decrement the negative elements by 1 (with saturation) + // if exponent is non-zero. + x = __builtin_msa_subs_s_w(x, sign); + // Arithmetic shift right with rounding. + // The srar instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srar_w(x, e); +} + +template <> +inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) { + v8i16 e = __builtin_msa_fill_h(exponent); + // Isolate the sign bits. + v8i16 sign = __builtin_msa_srli_h(x, 15); + // Reset them to 0 if exponent is 0. + sign = __builtin_msa_min_s_h(sign, e); + // Decrement the negative elements by 1 (with saturation) + // if exponent is non-zero. + x = __builtin_msa_subs_s_h(x, sign); + // Arithmetic shift right with rounding. + // The srar instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srar_h(x, e); +} template <> inline v4i32 Dup<v4i32>(std::int32_t x) { @@ -346,7 +406,6 @@ inline v8i16 Dup<v8i16>(std::int16_t x) { template <> inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) { return __builtin_msa_adds_s_h(a, b); - return a; } } // end namespace gemmlowp diff --git a/fixedpoint/fixedpoint_neon.h b/fixedpoint/fixedpoint_neon.h index 92b349b..4dab6c9 100644 --- a/fixedpoint/fixedpoint_neon.h +++ b/fixedpoint/fixedpoint_neon.h @@ -25,13 +25,13 @@ namespace gemmlowp { template <> struct FixedPointRawTypeTraits<int32x4_t> { typedef std::int32_t ScalarRawType; - static const int kLanes = 4; + static constexpr int kLanes = 4; }; template <> struct FixedPointRawTypeTraits<int16x8_t> { typedef std::int16_t ScalarRawType; - static const int kLanes = 8; + static constexpr int kLanes = 8; }; template <> @@ -115,6 +115,16 @@ inline int16x8_t ShiftLeft(int16x8_t a, int offset) { } template <> +inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) { + return vshlq_s32(a, offset); +} + +template <> +inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) { + return vshlq_s16(a, offset); +} + +template <> inline int32x4_t ShiftRight(int32x4_t a, int offset) { return vshlq_s32(a, vdupq_n_s32(-offset)); } @@ -282,6 +292,22 @@ inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) { return vrshlq_s16(fixed_up_x, shift_vec); } +template <> +inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) { + const int32x4_t shift_vec = vnegq_s32(exponent); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed_up_x, shift_vec); +} + +template <> +inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) { + const int16x8_t shift_vec = vnegq_s16(exponent); + const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15); + const int16x8_t fixed_up_x = vqaddq_s16(x, fixup); + return vrshlq_s16(fixed_up_x, shift_vec); +} + template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> { static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); } diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h index ba990f0..a1fae32 100644 --- a/fixedpoint/fixedpoint_sse.h +++ b/fixedpoint/fixedpoint_sse.h @@ -42,13 +42,13 @@ struct int16x8_m128i { template <> struct FixedPointRawTypeTraits<__m128i> { typedef std::int32_t ScalarRawType; - static const int kLanes = 4; + static constexpr int kLanes = 4; }; template <> struct FixedPointRawTypeTraits<int16x8_m128i> { typedef std::int16_t ScalarRawType; - static const int kLanes = 8; + static constexpr int kLanes = 8; }; template <> |