aboutsummaryrefslogtreecommitdiff
path: root/fixedpoint/fixedpoint.h
diff options
context:
space:
mode:
Diffstat (limited to 'fixedpoint/fixedpoint.h')
-rw-r--r--fixedpoint/fixedpoint.h147
1 files changed, 122 insertions, 25 deletions
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h
index e21337f..d39341b 100644
--- a/fixedpoint/fixedpoint.h
+++ b/fixedpoint/fixedpoint.h
@@ -50,6 +50,12 @@ struct FixedPointRawTypeTraits<std::int32_t> {
static const int kLanes = 1;
};
+template <>  // Raw-type traits for a scalar 16-bit integer.
+struct FixedPointRawTypeTraits<std::int16_t> {
+ typedef std::int16_t ScalarRawType;  // The raw type is itself the scalar type.
+ static const int kLanes = 1;  // A single lane, same as the int32 scalar traits.
+};
+
// Returns a SIMD value duplicating a scalar value across all lanes.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
@@ -217,6 +223,50 @@ inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
return static_cast<std::int32_t>((sum + sign) / 2);
}
+template <>  // int16 specialization: returns (a + b) / 2 with ties rounded away from zero.
+inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
+ std::int32_t a32 = a;  // Widen both operands to 32 bits so that
+ std::int32_t b32 = b;  // the sum below cannot overflow.
+ std::int32_t sum = a32 + b32;
+ std::int32_t sign = sum >= 0 ? 1 : -1;  // Nudge with the same sign as the sum.
+ return static_cast<std::int16_t>((sum + sign) / 2);  // Division truncates toward zero, so the nudge turns an exact half into the next integer away from zero.
+}
+
+template <typename IntegerType>  // Primary template: a placeholder that must be specialized per type.
+IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
+ static_assert(std::is_same<IntegerType, void>::value, "unimplemented");  // Depends on IntegerType, so it only fires when this primary template is instantiated.
+ return a;  // Placeholder return so the body is well-formed.
+}
+
+// int16 specialization: returns a + b clamped to the int16 range. So far this is only needed for int16.
+template <>
+inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
+ std::int32_t a32 = a;  // Widen so the exact sum is representable.
+ std::int32_t b32 = b;
+ std::int32_t sum = a32 + b32;
+ return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum)));  // Clamp to [-32768, 32767]. NOTE(review): std::min/std::max need both arguments to have one type, so this relies on std::int32_t being int.
+}
+
+// Returns a+b, saturating if the scalar integer type is 16 bits wide (the
+// dispatcher tests sizeof == 2), otherwise just a plain addition.
+template <typename IntegerType, bool Is16Bit>
+struct AddSaturatingIf16BitImpl {
+ static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }  // Primary template: non-saturating path through Add().
+};
+template <typename IntegerType>  // Partial specialization chosen when the Is16Bit argument is true.
+struct AddSaturatingIf16BitImpl<IntegerType, true> {
+ static IntegerType Run(IntegerType a, IntegerType b) {
+ return SaturatingAdd(a, b);  // Saturating path.
+ }
+};
+template <typename IntegerType>  // Entry point: selects the saturating or plain implementation from the scalar type's size.
+IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
+ using ScalarType =
+ typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;  // Scalar type the traits report for IntegerType.
+ return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
+ b);  // sizeof == 2 picks the SaturatingAdd specialization; any other size uses Add.
+}
+
// Returns the integer that represents the product of two fixed-point
// numbers, interpreting all integers as fixed-point values in the
// interval [-1, 1), rounding to the nearest value, and saturating
@@ -266,14 +316,23 @@ inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}
+template <>  // int16 specialization of the fixed-point product described in the comment preceding the int32 version.
+inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
+ std::int16_t b) {
+ bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min();  // Both operands are the int16 minimum: the quotient below would be 2^15, which does not fit in int16.
+ std::int32_t a_32(a);
+ std::int32_t b_32(b);
+ std::int32_t ab_32 = a_32 * b_32;  // Exact product, computed in 32 bits.
+ std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14));  // Rounding offset added before the truncating division.
+ std::int16_t ab_x2_high16 =
+ static_cast<std::int16_t>((ab_32 + nudge) / (1 << 15));  // Dividing by 2^15 is the same as doubling the product and dividing by 2^16.
+ return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16;  // Saturate the single overflowing case to the int16 maximum.
+}
+
// Correctly-rounded-to-nearest division by a power-of-two.
// Also known as a rounding arithmetic right shift.
template <typename IntegerType>
inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
- using ScalarIntegerType =
- typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
- static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
- "Currently only supporting int32 scalar and SIMD types");
assert(exponent >= 0);
assert(exponent <= 31);
const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
@@ -304,14 +363,14 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
static IntegerType eval(IntegerType x) {
using ScalarIntegerType =
typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
- static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
- "Currently only supporting int32 scalar and SIMD types");
const IntegerType min =
- Dup<IntegerType>(std::numeric_limits<std::int32_t>::min());
+ Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
const IntegerType max =
- Dup<IntegerType>(std::numeric_limits<std::int32_t>::max());
+ Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
+ const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
- const std::int32_t threshold = ((1 << (31 - Exponent)) - 1);
+ const std::int32_t threshold =
+ ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
const IntegerType positive_mask =
MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
const IntegerType negative_mask =
@@ -425,15 +484,16 @@ class FixedPoint {
static FixedPoint Zero() { return FromScalarRaw(0); }
static FixedPoint One() {
- return FromScalarRaw(kIntegerBits == 0
- ? ScalarRawMax()
- : (ScalarRawType(1) << kFractionalBits));
+ return FromScalarRaw(
+ kIntegerBits == 0
+ ? ScalarRawMax()  // With no integer bits, the maximum raw value is used in place of 1 << kFractionalBits.
+ : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));  // NOTE(review): the inner kIntegerBits == 0 check never changes the selected value; presumably it keeps the shift amount at 0 in the branch that is not taken - confirm.
}
static FixedPoint FromDouble(double x) {
const double min_bound = static_cast<double>(ScalarRawMin());
const double max_bound = static_cast<double>(ScalarRawMax());
- return FromScalarRaw(static_cast<std::int32_t>(std::min(
+ return FromScalarRaw(static_cast<ScalarRawType>(std::min(  // Cast to the class's own scalar raw type after clamping to [min_bound, max_bound].
std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
min_bound),
max_bound)));
@@ -555,6 +615,22 @@ bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
return !(a == b);
}
+template <typename tRawType, int tIntegerBits>  // FixedPoint overload: applies SaturatingAdd to the operands' raw values.
+FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
+ FixedPoint<tRawType, tIntegerBits> a,
+ FixedPoint<tRawType, tIntegerBits> b) {
+ return FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ SaturatingAdd(a.raw(), b.raw()));  // The result is rewrapped in the same FixedPoint type as the operands.
+}
+
+template <typename tRawType, int tIntegerBits>  // FixedPoint overload: applies the raw-type AddSaturatingIf16Bit to the operands' raw values.
+FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
+ FixedPoint<tRawType, tIntegerBits> a,
+ FixedPoint<tRawType, tIntegerBits> b) {
+ return FixedPoint<tRawType, tIntegerBits>::FromRaw(
+ AddSaturatingIf16Bit(a.raw(), b.raw()));  // The result is rewrapped in the same FixedPoint type as the operands.
+}
+
// Conversion to floating-point.
template <typename tRawType, int tIntegerBits>
double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
@@ -579,23 +655,41 @@ FixedPoint<tRawType, tIntegerBitsDst> Rescale(
// initialized as real numbers, in a way that does not compile floating-point
// arithmetic in production code, yet still checks agreement with the
// floating-point expressions when asserts are enabled.
+//
+// The raw integer value provided is always an int32, encoding a 32-bit
+// fixed-point value, regardless of the actual Scalar type. This allows
+// writing generic code that applies just as well to the 32-bit and 16-bit
+// cases. In the 16-bit case, the raw integer value is internally
+// rounding-shifted by 16 bits to the right.
+template <typename FixedPointType>  // Converts a 32-bit raw constant to the width of FixedPointType's scalar raw type.
+inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
+ std::int32_t int32_value) {
+ typedef typename FixedPointType::ScalarRawType ScalarRawType;
+ static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);  // Width of the target scalar type in bits.
+ return static_cast<ScalarRawType>(
+ RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));  // Shift amount is 32 - ScalarTypeBits: 0 for a 32-bit scalar, 16 for a 16-bit scalar.
+}
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
template <typename FixedPointType>
-FixedPointType CheckedFixedPointConstant(
- typename FixedPointType::ScalarRawType raw_value, double double_value) {
- typedef typename FixedPointType::RawType RawType;
+FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,  // Raw constant handed to FromScalarRaw; the macro that calls this passes the output of RescaleConstantInitializer.
+ double double_value) {  // Real value that the raw constant is asserted to agree with.
const FixedPointType result = FixedPointType::FromScalarRaw(raw_value);
assert(result == FixedPointType::FromDouble(double_value));
return result;
}
-#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
- DoubleValue) \
- (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue))
+#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \
+ ScalarRawInt32Value, DoubleValue) \
+ (gemmlowp::CheckedFixedPointConstant<FixedPointType>( \
+ gemmlowp::RescaleConstantInitializer<FixedPointType>( \
+ ScalarRawInt32Value), \
+ DoubleValue))  // Checked variant: rescales the int32 initializer, then CheckedFixedPointConstant asserts it against DoubleValue.
#else
-#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
- DoubleValue) \
- (FixedPointType::FromScalarRaw(ScalarRawValue))
+#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \
+ ScalarRawInt32Value, DoubleValue) \
+ (FixedPointType::FromScalarRaw( \
+ gemmlowp::RescaleConstantInitializer<FixedPointType>( \
+ ScalarRawInt32Value)))  // Unchecked variant: rescales the int32 initializer; DoubleValue is not used in the expansion.
#endif
// Implementation of exponential function.
@@ -620,8 +714,9 @@ FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
SaturatingRoundingMultiplyByPOT<-1>(
((x4_over_4 + x3) * constant_1_over_3) + x2);
- return constant_term +
- constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2);
+ return AddSaturatingIf16Bit(
+ constant_term,
+ constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
}
// Returns exp(x) for x < 0.
@@ -661,7 +756,7 @@ FixedPoint<tRawType, 0> exp_on_negative_values(
#undef GEMMLOWP_EXP_BARREL_SHIFTER
if (kIntegerBits > 5) {
- static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0;
+ static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
const InputF clamp =
GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
@@ -774,6 +869,8 @@ FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
#include "./fixedpoint_neon.h"
#elif defined(GEMMLOWP_SSE4)
#include "./fixedpoint_sse.h"
+#elif defined(GEMMLOWP_MSA)
+#include "./fixedpoint_msa.h"
#endif
#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_