aboutsummaryrefslogtreecommitdiff
path: root/fixedpoint/fixedpoint_msa.h
diff options
context:
space:
mode:
Diffstat (limited to 'fixedpoint/fixedpoint_msa.h')
-rw-r--r--fixedpoint/fixedpoint_msa.h75
1 files changed, 67 insertions, 8 deletions
diff --git a/fixedpoint/fixedpoint_msa.h b/fixedpoint/fixedpoint_msa.h
index c7a110c..b17f32a 100644
--- a/fixedpoint/fixedpoint_msa.h
+++ b/fixedpoint/fixedpoint_msa.h
@@ -25,13 +25,13 @@ namespace gemmlowp {
template <>
struct FixedPointRawTypeTraits<v4i32> {
typedef std::int32_t ScalarRawType;
- static const int kLanes = 4;
+ static constexpr int kLanes = 4;
};
template <>
struct FixedPointRawTypeTraits<v8i16> {
typedef std::int16_t ScalarRawType;
- static const int kLanes = 8;
+ static constexpr int kLanes = 8;
};
template <>
@@ -326,11 +326,71 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> {
}
};
-// TODO: possibly implement:
-// template <> v4i32 RoundingDivideByPOT(v4i32, int)
-// template <> v8i16 RoundingDivideByPOT(v8i16, int)
-// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1>
-// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1>
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1> {
+ static v4i32 eval(v4i32 x) {
+ static_assert(-31 <= Exponent && Exponent <= -1, "");
+ // Isolate the sign bits.
+ v4i32 sign = __builtin_msa_srli_w(x, 31);
+ // Decrement the negative elements by 1 (with saturation).
+ x = __builtin_msa_subs_s_w(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srari instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srari_w(x, -Exponent);
+ }
+};
+
+template <int Exponent>
+struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1> {
+ static v8i16 eval(v8i16 x) {
+ static_assert(-15 <= Exponent && Exponent <= -1, "");
+ // Isolate the sign bits.
+ v8i16 sign = __builtin_msa_srli_h(x, 15);
+ // Decrement the negative elements by 1 (with saturation).
+ x = __builtin_msa_subs_s_h(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srari instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srari_h(x, -Exponent);
+ }
+};
+
+template <>
+inline v4i32 RoundingDivideByPOT(v4i32 x, int exponent) {
+ v4i32 e = __builtin_msa_fill_w(exponent);
+ // Isolate the sign bits.
+ v4i32 sign = __builtin_msa_srli_w(x, 31);
+ // Reset them to 0 if exponent is 0.
+ sign = __builtin_msa_min_s_w(sign, e);
+ // Decrement the negative elements by 1 (with saturation)
+ // if exponent is non-zero.
+ x = __builtin_msa_subs_s_w(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srar instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srar_w(x, e);
+}
+
+template <>
+inline v8i16 RoundingDivideByPOT(v8i16 x, int exponent) {
+ v8i16 e = __builtin_msa_fill_h(exponent);
+ // Isolate the sign bits.
+ v8i16 sign = __builtin_msa_srli_h(x, 15);
+ // Reset them to 0 if exponent is 0.
+ sign = __builtin_msa_min_s_h(sign, e);
+ // Decrement the negative elements by 1 (with saturation)
+ // if exponent is non-zero.
+ x = __builtin_msa_subs_s_h(x, sign);
+ // Arithmetic shift right with rounding.
+ // The srar instruction rounds all midpoint values towards +infinity.
+ // It will correctly round negative midpoint values as we just
+ // decremented the negative values by 1.
+ return __builtin_msa_srar_h(x, e);
+}
template <>
inline v4i32 Dup<v4i32>(std::int32_t x) {
@@ -346,7 +406,6 @@ inline v8i16 Dup<v8i16>(std::int16_t x) {
template <>
inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
return __builtin_msa_adds_s_h(a, b);
- return a;
}
} // end namespace gemmlowp