aboutsummaryrefslogtreecommitdiff
path: root/fixedpoint
diff options
context:
space:
mode:
Diffstat (limited to 'fixedpoint')
-rw-r--r--fixedpoint/fixedpoint.h68
-rw-r--r--fixedpoint/fixedpoint_avx.h218
-rw-r--r--fixedpoint/fixedpoint_msa.h75
-rw-r--r--fixedpoint/fixedpoint_neon.h30
-rw-r--r--fixedpoint/fixedpoint_sse.h4
5 files changed, 361 insertions, 34 deletions
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index d39341b..58e8050 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -18,10 +18,13 @@
#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
+#include <algorithm>
#include <cassert>
+#include <cmath>
+#include <cstdint>
#include <limits>
-#include "../internal/common.h"
+#include "../internal/detect_platform.h"
namespace gemmlowp {
@@ -47,13 +50,13 @@ struct FixedPointRawTypeTraits {};
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
typedef std::int32_t ScalarRawType;
- static const int kLanes = 1;
+ static constexpr int kLanes = 1;
};
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
typedef std::int16_t ScalarRawType;
- static const int kLanes = 1;
+ static constexpr int kLanes = 1;
};
// Returns a SIMD value duplicating a scalar value across all lanes.
@@ -109,11 +112,25 @@ tIntegerType Neg(tIntegerType a) {
return -a;
}
-// Integer arithmetic left-shift, equivalent to multiplying with a
-// power of two. Not saturating. Overflow is undefined behavior.
-template <typename tIntegerType>
-tIntegerType ShiftLeft(tIntegerType a, int offset) {
- return a << offset;
+// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
+// Negative values are OK. In case of overflow, no Undefined
+// Behavior, but the results are implementation-defined (in practice,
+// they currently are saturated, but we make no commitment to that). The idea
+// is that the caller will want to implement the overflowing cases with
+// saturation with compare-and-mask, so we don't care about the results
+// in the overflow case, we just want to avoid undefined behavior.
+//
+// tIntegerType may be int32 or any narrower signed type.
+template <typename tIntegerType, typename OffsetType>
+tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
+ const std::int64_t wide_a = static_cast<std::int64_t>(a);
+ const std::int64_t wide_shifted = wide_a * (1 << offset);
+ const auto min = std::numeric_limits<tIntegerType>::min();
+ const auto max = std::numeric_limits<tIntegerType>::max();
+ return wide_shifted < min
+ ? min
+ : wide_shifted > max ? max
+ : static_cast<tIntegerType>(wide_shifted);
}
// Integer arithmetic right-shift. Not rounding.
@@ -137,7 +154,7 @@ tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
- static const tIntegerType zero = 0;
+ static constexpr tIntegerType zero = 0;
return a ? BitNot(zero) : zero;
}
@@ -211,6 +228,7 @@ bool Any(tIntegerType a) {
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ (void)b;
return a;
}
@@ -235,6 +253,7 @@ inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ (void)b;
return a;
}
@@ -244,7 +263,9 @@ inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
std::int32_t a32 = a;
std::int32_t b32 = b;
std::int32_t sum = a32 + b32;
- return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum)));
+ return static_cast<std::int16_t>(
+ std::min(static_cast<std::int32_t>(32767),
+ std::max(static_cast<std::int32_t>(-32768), sum)));
}
// Returns a+b, saturating if the integers are 16bit or narrower,
@@ -298,6 +319,7 @@ IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
+ (void)b;
return a;
}
@@ -331,8 +353,8 @@ inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
-template <typename IntegerType>
-inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
+template <typename IntegerType, typename ExponentType>
+inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) {
assert(exponent >= 0);
assert(exponent <= 31);
const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
@@ -432,9 +454,9 @@ class FixedPoint {
typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
- static const int kTotalBits = 8 * sizeof(ScalarRawType);
- static const int kIntegerBits = tIntegerBits;
- static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
+ static constexpr int kTotalBits = 8 * sizeof(ScalarRawType);
+ static constexpr int kIntegerBits = tIntegerBits;
+ static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits;
static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
"bad IntegerBits");
@@ -474,7 +496,7 @@ class FixedPoint {
template <int Exponent>
static FixedPoint ConstantPOT() {
- static const int kOffset = kFractionalBits + Exponent;
+ static constexpr int kOffset = kFractionalBits + Exponent;
static_assert(
kOffset < 31,
"Constant not exactly representable in this fixed-point format");
@@ -645,7 +667,7 @@ double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
FixedPoint<tRawType, tIntegerBitsDst> Rescale(
FixedPoint<tRawType, tIntegerBitsSrc> x) {
- static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
+ static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
FixedPoint<tRawType, tIntegerBitsDst> result;
result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
return result;
@@ -725,9 +747,9 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
FixedPoint<tRawType, tIntegerBits> a) {
typedef FixedPoint<tRawType, tIntegerBits> InputF;
typedef FixedPoint<tRawType, 0> ResultF;
- static const int kFractionalBits = InputF::kFractionalBits;
- static const int kIntegerBits = InputF::kIntegerBits;
- static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
+ static constexpr int kFractionalBits = InputF::kFractionalBits;
+ static constexpr int kIntegerBits = InputF::kIntegerBits;
+ const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
@@ -755,10 +777,10 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
#undef GEMMLOWP_EXP_BARREL_SHIFTER
+ static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
if (kIntegerBits > 5) {
- static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
const InputF clamp =
- GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0);
result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
}
@@ -867,6 +889,8 @@ FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
#ifdef GEMMLOWP_NEON
#include "./fixedpoint_neon.h"
+#elif defined(GEMMLOWP_AVX2)
+#include "./fixedpoint_avx.h"
#elif defined(GEMMLOWP_SSE4)
#include "./fixedpoint_sse.h"
#elif defined(GEMMLOWP_MSA)
diff --git a/fixedpoint/fixedpoint_avx.h b/fixedpoint/fixedpoint_avx.h
new file mode 100644
index 0000000..1816386
--- /dev/null
+++ b/fixedpoint/fixedpoint_avx.h
@@ -0,0 +1,218 @@
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// fixedpoint_avx.h: optimized avx specializations of the templates
+// in fixedpoint.h.
+
+#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
+#define GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
+
+#include <smmintrin.h>
+#include "fixedpoint.h"
+#include "fixedpoint_sse.h"
+
+namespace gemmlowp {
+
+template <>
+struct FixedPointRawTypeTraits<__m256i> {
+ typedef std::int32_t ScalarRawType;
+ static const int kLanes = 4;
+};
+
+template <>
+inline __m256i BitAnd(__m256i a, __m256i b) {
+ return _mm256_and_si256(a, b);
+}
+
+template <>
+inline __m256i BitOr(__m256i a, __m256i b) {
+ return _mm256_or_si256(a, b);
+}
+
+template <>
+inline __m256i BitXor(__m256i a, __m256i b) {
+ return _mm256_xor_si256(a, b);
+}
+
+template <>
+inline __m256i BitNot(__m256i a) {
+ return _mm256_andnot_si256(a, _mm256_set1_epi32(-1));
+}
+
+template <>
+inline __m256i Add(__m256i a, __m256i b) {
+ return _mm256_add_epi32(a, b);
+}
+
+template <>
+inline __m256i Mul(__m256i a, __m256i b) {
+ return _mm256_mullo_epi32(a, b);
+}
+
+template <>
+inline __m256i Sub(__m256i a, __m256i b) {
+ return _mm256_sub_epi32(a, b);
+}
+
+template <>
+inline __m256i Neg(__m256i a) {
+ return _mm256_sign_epi32(a, _mm256_set1_epi32(-1));
+}
+
+template <>
+inline __m256i ShiftLeft(__m256i a, int offset) {
+ return _mm256_slli_epi32(a, offset);
+}
+
+template <>
+inline __m256i ShiftRight(__m256i a, int offset) {
+ return _mm256_srai_epi32(a, offset);
+}
+
+template <>
+inline __m256i SelectUsingMask(__m256i if_mask, __m256i then_val,
+ __m256i else_val) {
+ return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(else_val),
+ _mm256_castsi256_ps(then_val),
+ _mm256_castsi256_ps(if_mask)));
+}
+
+template <>
+inline __m256i MaskIfEqual(__m256i a, __m256i b) {
+ return _mm256_cmpeq_epi32(a, b);
+}
+
+template <>
+inline __m256i MaskIfNotEqual(__m256i a, __m256i b) {
+ return BitNot(MaskIfEqual(a, b));
+}
+
+template <>
+inline __m256i MaskIfZero(__m256i a) {
+ return MaskIfEqual(a, _mm256_set1_epi32(0));
+}
+
+template <>
+inline __m256i MaskIfNonZero(__m256i a) {
+ return MaskIfNotEqual(a, _mm256_set1_epi32(0));
+}
+
+template <>
+inline __m256i MaskIfGreaterThan(__m256i a, __m256i b) {
+ return _mm256_cmpgt_epi32(a, b);
+}
+
+template <>
+inline __m256i MaskIfLessThan(__m256i a, __m256i b) {
+ return _mm256_cmpgt_epi32(b, a);
+}
+
+template <>
+inline __m256i MaskIfGreaterThanOrEqual(__m256i a, __m256i b) {
+ return BitNot(MaskIfLessThan(a, b));
+}
+
+template <>
+inline __m256i MaskIfLessThanOrEqual(__m256i a, __m256i b) {
+ return BitNot(MaskIfGreaterThan(a, b));
+}
+
+/* Assumptions:
+ - All and Any are used on masks.
+ - masks are all_ones for true lanes, all_zeroes otherwise.
+Hence, All means all 128bits set, and Any means any bit set.
+*/
+
+template <>
+inline bool All(__m256i a) {
+ return _mm256_testc_si256(a, a);
+}
+
+template <>
+inline bool Any(__m256i a) {
+ return BitNot(_mm256_testz_si256(a, a));
+}
+
+template <>
+inline __m256i RoundingHalfSum(__m256i a, __m256i b) {
+ /* __m256i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */
+ /* We divide the inputs before the add to avoid the overflow and costly test
+ */
+ /* of checking if an overflow occured on signed add */
+ /* round_bit_mask = _mm_set1_epi32(1); */
+ /* a_over_2 = _mm_srai_epi32(a, 1); */
+ /* b_over_2 = _mm_srai_epi32(b, 1); */
+ /* sum = Add(a_over_2, b_over_2); */
+ /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a,b), round_bit_mask), sum); */
+ /* return Add(sum, round_bit); */
+
+ /* Other possibility detecting overflow and xor the sign if an overflow
+ * happened*/
+ __m256i one, sign_bit_mask, sum, rounded_half_sum, overflow, result;
+ one = _mm256_set1_epi32(1);
+ sign_bit_mask = _mm256_set1_epi32(0x80000000);
+ sum = Add(a, b);
+ rounded_half_sum = _mm256_srai_epi32(Add(sum, one), 1);
+ overflow =
+ BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
+ sign_bit_mask);
+ result = BitXor(rounded_half_sum, overflow);
+ return result;
+}
+
+template <>
+inline __m256i SaturatingRoundingDoublingHighMul(__m256i a, __m256i b) {
+ __m256i min, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
+ __m256i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
+ __m256i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result;
+ __m256i nudge;
+
+ // saturation only happen if a == b == INT_MIN
+ min = _mm256_set1_epi32(std::numeric_limits<std::int32_t>::min());
+ saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min));
+
+ // a = a0 | a1 | a2 | a3
+ // b = b0 | b1 | b2 | b3
+ a0_a2 = a;
+ a1_a3 = _mm256_srli_si256(a, 4);
+ b0_b2 = b;
+ b1_b3 = _mm256_srli_si256(b, 4);
+
+ a0b0_a2b2 = _mm256_mul_epi32(a0_a2, b0_b2);
+ a1b1_a3b3 = _mm256_mul_epi32(a1_a3, b1_b3);
+
+ // do the rounding and take into account that it will be doubled
+ nudge = _mm256_set1_epi64x(1 << 30);
+ a0b0_a2b2_rounded = _mm256_add_epi64(a0b0_a2b2, nudge);
+ a1b1_a3b3_rounded = _mm256_add_epi64(a1b1_a3b3, nudge);
+
+ // do the doubling
+ a0b0_a2b2_rounded_2x = _mm256_slli_epi64(a0b0_a2b2_rounded, 1);
+ a1b1_a3b3_rounded_2x = _mm256_slli_epi64(a1b1_a3b3_rounded, 1);
+
+ // get the high part of the products
+ result = _mm256_blend_epi16(_mm256_srli_si256(a0b0_a2b2_rounded_2x, 4),
+ a1b1_a3b3_rounded_2x, 0xcc);
+
+ // saturate those which overflowed
+ return SelectUsingMask(saturation_mask, min, result);
+}
+
+template <>
+inline __m256i Dup<__m256i>(std::int32_t x) {
+ return _mm256_set1_epi32(x);
+}
+
+} // end namespace gemmlowp
+
+#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_AVX_H_
diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h
index c7a110c..b17f32a 100644
--- a/fixedpoint/fixedpoint_msa.h
+++ b/fixedpoint/fixedpoint_msa.h
@@ -25,13 +25,13 @@ namespace gemmlowp {
template <>
struct FixedPointRawTypeTraits<v4i32> {
typedef std::int32_t ScalarRawType;
- static const int kLanes = 4;
+ static constexpr int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<v8i16> {
typedef std::int16_t ScalarRawType;
- static const int kLanes = 8;
+ static constexpr int kLanes = 8;
};
template <>
@@ -326,11 +326,71 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> {
}
};
-// TODO: possibly implement:
-// template <> v4i32 RoundingDivideByPOT(v4i32, int)
-// template <> v8i16 RoundingDivideByPOT(v8i16, int)
-// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1>
-// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1>
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> {
+ static v4i32 eval(v4i32 x) {
+ static_assert(-31 <= Exponent && Exponent <= -1, "");
+ // Isolate the sign bits.
+ v4i32 sign = __builtin_msa_srli_w(x, 31);
+ // Decrement the negative elements by 1 (with saturation).
+ x = __builtin_msa_subs_s_w(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srari instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srari_w(x, -Exponent);
+ }
+};
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> {
+ static v8i16 eval(v8i16 x) {
+ static_assert(-15 <= Exponent && Exponent <= -1, "");
+ // Isolate the sign bits.
+ v8i16 sign = __builtin_msa_srli_h(x, 15);
+ // Decrement the negative elements by 1 (with saturation).
+ x = __builtin_msa_subs_s_h(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srari instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srari_h(x, -Exponent);
+ }
+};
+
+template <>
+inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) {
+ v4i32 e = __builtin_msa_fill_w(exponent);
+ // Isolate the sign bits.
+ v4i32 sign = __builtin_msa_srli_w(x, 31);
+ // Reset them to 0 if exponent is 0.
+ sign = __builtin_msa_min_s_w(sign, e);
+ // Decrement the negative elements by 1 (with saturation)
+ // if exponent is non-zero.
+ x = __builtin_msa_subs_s_w(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srar instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srar_w(x, e);
+}
+
+template <>
+inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) {
+ v8i16 e = __builtin_msa_fill_h(exponent);
+ // Isolate the sign bits.
+ v8i16 sign = __builtin_msa_srli_h(x, 15);
+ // Reset them to 0 if exponent is 0.
+ sign = __builtin_msa_min_s_h(sign, e);
+ // Decrement the negative elements by 1 (with saturation)
+ // if exponent is non-zero.
+ x = __builtin_msa_subs_s_h(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srar instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srar_h(x, e);
+}
template <>
inline v4i32 Dup<v4i32>(std::int32_t x) {
@@ -346,7 +406,6 @@ inline v8i16 Dup<v8i16>(std::int16_t x) {
template <>
inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
return __builtin_msa_adds_s_h(a, b);
- return a;
}
} // end namespace gemmlowp
diff --git a/fixedpoint/fixedpoint_neon.h b/fixedpoint/fixedpoint_neon.h
index 92b349b..4dab6c9 100644
--- a/fixedpoint/fixedpoint_neon.h
+++ b/fixedpoint/fixedpoint_neon.h
@@ -25,13 +25,13 @@ namespace gemmlowp {
template <>
struct FixedPointRawTypeTraits<int32x4_t> {
typedef std::int32_t ScalarRawType;
- static const int kLanes = 4;
+ static constexpr int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<int16x8_t> {
typedef std::int16_t ScalarRawType;
- static const int kLanes = 8;
+ static constexpr int kLanes = 8;
};
template <>
@@ -115,6 +115,16 @@ inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
}
template <>
+inline int32x4_t ShiftLeft(int32x4_t a, int32x4_t offset) {
+ return vshlq_s32(a, offset);
+}
+
+template <>
+inline int16x8_t ShiftLeft(int16x8_t a, int16x8_t offset) {
+ return vshlq_s16(a, offset);
+}
+
+template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
return vshlq_s32(a, vdupq_n_s32(-offset));
}
@@ -282,6 +292,22 @@ inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
return vrshlq_s16(fixed_up_x, shift_vec);
}
+template <>
+inline int32x4_t RoundingDivideByPOT(int32x4_t x, int32x4_t exponent) {
+ const int32x4_t shift_vec = vnegq_s32(exponent);
+ const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
+ const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
+ return vrshlq_s32(fixed_up_x, shift_vec);
+}
+
+template <>
+inline int16x8_t RoundingDivideByPOT(int16x8_t x, int16x8_t exponent) {
+ const int16x8_t shift_vec = vnegq_s16(exponent);
+ const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
+ const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
+ return vrshlq_s16(fixed_up_x, shift_vec);
+}
+
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
diff --git a/fixedpoint/fixedpoint_sse.h b/fixedpoint/fixedpoint_sse.h
index ba990f0..a1fae32 100644
--- a/fixedpoint/fixedpoint_sse.h
+++ b/fixedpoint/fixedpoint_sse.h
@@ -42,13 +42,13 @@ struct int16x8_m128i {
template <>
struct FixedPointRawTypeTraits<__m128i> {
typedef std::int32_t ScalarRawType;
- static const int kLanes = 4;
+ static constexpr int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<int16x8_m128i> {
typedef std::int16_t ScalarRawType;
- static const int kLanes = 8;
+ static constexpr int kLanes = 8;
};
template <>