diff options
Diffstat (limited to 'fixedpoint/fixedpoint_msa.h')
-rw-r--r-- | fixedpoint/fixedpoint_msa.h | 75 |
1 files changed, 67 insertions, 8 deletions
diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h index c7a110c..b17f32a 100644 --- a/fixedpoint/fixedpoint_msa.h +++ b/fixedpoint/fixedpoint_msa.h @@ -25,13 +25,13 @@ namespace gemmlowp { template <> struct FixedPointRawTypeTraits<v4i32> { typedef std::int32_t ScalarRawType; - static const int kLanes = 4; + static constexpr int kLanes = 4; }; template <> struct FixedPointRawTypeTraits<v8i16> { typedef std::int16_t ScalarRawType; - static const int kLanes = 8; + static constexpr int kLanes = 8; }; template <> @@ -326,11 +326,71 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> { } }; -// TODO: possibly implement: -// template <> v4i32 RoundingDivideByPOT(v4i32, int) -// template <> v8i16 RoundingDivideByPOT(v8i16, int) -// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> -// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> +template <int Exponent> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> { + static v4i32 eval(v4i32 x) { + static_assert(-31 <= Exponent && Exponent <= -1, ""); + // Isolate the sign bits. + v4i32 sign = __builtin_msa_srli_w(x, 31); + // Decrement the negative elements by 1 (with saturation). + x = __builtin_msa_subs_s_w(x, sign); + // Arithmetic shift right with rounding. + // The srari instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srari_w(x, -Exponent); + } +}; + +template <int Exponent> +struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> { + static v8i16 eval(v8i16 x) { + static_assert(-15 <= Exponent && Exponent <= -1, ""); + // Isolate the sign bits. + v8i16 sign = __builtin_msa_srli_h(x, 15); + // Decrement the negative elements by 1 (with saturation). + x = __builtin_msa_subs_s_h(x, sign); + // Arithmetic shift right with rounding. + // The srari instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srari_h(x, -Exponent); + } +}; + +template <> +inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) { + v4i32 e = __builtin_msa_fill_w(exponent); + // Isolate the sign bits. + v4i32 sign = __builtin_msa_srli_w(x, 31); + // Reset them to 0 if exponent is 0. + sign = __builtin_msa_min_s_w(sign, e); + // Decrement the negative elements by 1 (with saturation) + // if exponent is non-zero. + x = __builtin_msa_subs_s_w(x, sign); + // Arithmetic shift right with rounding. + // The srar instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srar_w(x, e); +} + +template <> +inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) { + v8i16 e = __builtin_msa_fill_h(exponent); + // Isolate the sign bits. + v8i16 sign = __builtin_msa_srli_h(x, 15); + // Reset them to 0 if exponent is 0. + sign = __builtin_msa_min_s_h(sign, e); + // Decrement the negative elements by 1 (with saturation) + // if exponent is non-zero. + x = __builtin_msa_subs_s_h(x, sign); + // Arithmetic shift right with rounding. + // The srar instruction rounds all midpoint values towards +infinity. + // It will correctly round negative midpoint values as we just + // decremented the negative values by 1. + return __builtin_msa_srar_h(x, e); +} template <> inline v4i32 Dup<v4i32>(std::int32_t x) { @@ -346,7 +406,6 @@ inline v8i16 Dup<v8i16>(std::int16_t x) { template <> inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) { return __builtin_msa_adds_s_h(a, b); - return a; } } // end namespace gemmlowp |