diff options
Diffstat (limited to 'fixedpoint/fixedpoint.h')
-rw-r--r-- | fixedpoint/fixedpoint.h | 147 |
1 files changed, 122 insertions, 25 deletions
diff --git a/fixedpoint/fixedpoint.h b/fixedpoint/fixedpoint.h index e21337f..d39341b 100644 --- a/fixedpoint/fixedpoint.h +++ b/fixedpoint/fixedpoint.h @@ -50,6 +50,12 @@ struct FixedPointRawTypeTraits<std::int32_t> { static const int kLanes = 1; }; +template <> +struct FixedPointRawTypeTraits<std::int16_t> { + typedef std::int16_t ScalarRawType; + static const int kLanes = 1; +}; + // Returns a SIMD value duplicating a scalar value across all lanes. template <typename tRawType> tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) { @@ -217,6 +223,50 @@ inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) { return static_cast<std::int32_t>((sum + sign) / 2); } +template <> +inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) { + std::int32_t a32 = a; + std::int32_t b32 = b; + std::int32_t sum = a32 + b32; + std::int32_t sign = sum >= 0 ? 1 : -1; + return static_cast<std::int16_t>((sum + sign) / 2); +} + +template <typename IntegerType> +IntegerType SaturatingAdd(IntegerType a, IntegerType b) { + static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); + return a; +} + +// So far this is only needed for int16. +template <> +inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) { + std::int32_t a32 = a; + std::int32_t b32 = b; + std::int32_t sum = a32 + b32; + return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum))); +} + +// Returns a+b, saturating if the integers are 16bit or narrower, +// otherwise just a plain addition. +template <typename IntegerType, bool Is16Bit> +struct AddSaturatingIf16BitImpl { + static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); } +}; +template <typename IntegerType> +struct AddSaturatingIf16BitImpl<IntegerType, true> { + static IntegerType Run(IntegerType a, IntegerType b) { + return SaturatingAdd(a, b); + } +}; +template <typename IntegerType> +IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { + using ScalarType = + typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; + return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a, + b); +} + // Returns the integer that represents the product of two fixed-point // numbers, interpreting all integers as fixed-point values in the // interval [-1, 1), rounding to the nearest value, and saturating @@ -266,14 +316,23 @@ inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32; } +template <> +inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, + std::int16_t b) { + bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min(); + std::int32_t a_32(a); + std::int32_t b_32(b); + std::int32_t ab_32 = a_32 * b_32; + std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14)); + std::int16_t ab_x2_high16 = + static_cast<std::int16_t>((ab_32 + nudge) / (1 << 15)); + return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16; +} + // Correctly-rounded-to-nearest division by a power-of-two. // Also known as a rounding arithmetic right shift. template <typename IntegerType> inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) { - using ScalarIntegerType = - typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; - static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value, - "Currently only supporting int32 scalar and SIMD types"); assert(exponent >= 0); assert(exponent <= 31); const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); @@ -304,14 +363,14 @@ struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> { static IntegerType eval(IntegerType x) { using ScalarIntegerType = typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; - static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value, - "Currently only supporting int32 scalar and SIMD types"); const IntegerType min = - Dup<IntegerType>(std::numeric_limits<std::int32_t>::min()); + Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min()); const IntegerType max = - Dup<IntegerType>(std::numeric_limits<std::int32_t>::max()); + Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); - const std::int32_t threshold = ((1 << (31 - Exponent)) - 1); + const std::int32_t threshold = + ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1); const IntegerType positive_mask = MaskIfGreaterThan(x, Dup<IntegerType>(threshold)); const IntegerType negative_mask = @@ -425,15 +484,16 @@ class FixedPoint { static FixedPoint Zero() { return FromScalarRaw(0); } static FixedPoint One() { - return FromScalarRaw(kIntegerBits == 0 - ? ScalarRawMax() - : (ScalarRawType(1) << kFractionalBits)); + return FromScalarRaw( + kIntegerBits == 0 + ? ScalarRawMax() + : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits))); } static FixedPoint FromDouble(double x) { const double min_bound = static_cast<double>(ScalarRawMin()); const double max_bound = static_cast<double>(ScalarRawMax()); - return FromScalarRaw(static_cast<std::int32_t>(std::min( + return FromScalarRaw(static_cast<ScalarRawType>(std::min( std::max(round(x * static_cast<double>(1ll << kFractionalBits)), min_bound), max_bound))); @@ -555,6 +615,22 @@ bool operator!=(FixedPoint<tRawType, tIntegerBits> a, return !(a == b); } +template <typename tRawType, int tIntegerBits> +FixedPoint<tRawType, tIntegerBits> SaturatingAdd( + FixedPoint<tRawType, tIntegerBits> a, + FixedPoint<tRawType, tIntegerBits> b) { + return FixedPoint<tRawType, tIntegerBits>::FromRaw( + SaturatingAdd(a.raw(), b.raw())); +} + +template <typename tRawType, int tIntegerBits> +FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit( + FixedPoint<tRawType, tIntegerBits> a, + FixedPoint<tRawType, tIntegerBits> b) { + return FixedPoint<tRawType, tIntegerBits>::FromRaw( + AddSaturatingIf16Bit(a.raw(), b.raw())); +} + // Conversion to floating-point. template <typename tRawType, int tIntegerBits> double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { @@ -579,23 +655,41 @@ FixedPoint<tRawType, tIntegerBitsDst> Rescale( // initialized as real numbers, in a way that does not compile floating-point // arithmetic in production code, yet still checks agreement with the // floating-point expressions when asserts are enabled. +// +// The raw integer value provided is always a int32, encoding a 32-bit +// fixed-point value, regardless of the actual Scalar type. This allows +// writing generic code that applies just as well to the 32-bit and 16-bit +// cases. In the 16-bit case, the raw integer value is internally +// rounding-shifted by 16 bits to the right. +template <typename FixedPointType> +inline typename FixedPointType::ScalarRawType RescaleConstantInitializer( + std::int32_t int32_value) { + typedef typename FixedPointType::ScalarRawType ScalarRawType; + static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType); + return static_cast<ScalarRawType>( + RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits)); +} #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS template <typename FixedPointType> -FixedPointType CheckedFixedPointConstant( - typename FixedPointType::ScalarRawType raw_value, double double_value) { - typedef typename FixedPointType::RawType RawType; +FixedPointType CheckedFixedPointConstant(std::int32_t raw_value, + double double_value) { const FixedPointType result = FixedPointType::FromScalarRaw(raw_value); assert(result == FixedPointType::FromDouble(double_value)); return result; } -#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ - DoubleValue) \ - (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue)) +#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ + ScalarRawInt32Value, DoubleValue) \ + (gemmlowp::CheckedFixedPointConstant<FixedPointType>( \ + gemmlowp::RescaleConstantInitializer<FixedPointType>( \ + ScalarRawInt32Value), \ + DoubleValue)) #else -#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ - DoubleValue) \ - (FixedPointType::FromScalarRaw(ScalarRawValue)) +#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ + ScalarRawInt32Value, DoubleValue) \ + (FixedPointType::FromScalarRaw( \ + gemmlowp::RescaleConstantInitializer<FixedPointType>( \ + ScalarRawInt32Value))) #endif // Implementation of exponential function. @@ -620,8 +714,9 @@ FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( F x4_over_24_plus_x3_over_6_plus_x2_over_2 = SaturatingRoundingMultiplyByPOT<-1>( ((x4_over_4 + x3) * constant_1_over_3) + x2); - return constant_term + - constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2); + return AddSaturatingIf16Bit( + constant_term, + constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); } // Returns exp(x) for x < 0. @@ -661,7 +756,7 @@ FixedPoint<tRawType, 0> exp_on_negative_values( #undef GEMMLOWP_EXP_BARREL_SHIFTER if (kIntegerBits > 5) { - static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0; + static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0; const InputF clamp = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0); result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); @@ -774,6 +869,8 @@ FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) { #include "./fixedpoint_neon.h" #elif defined(GEMMLOWP_SSE4) #include "./fixedpoint_sse.h" +#elif defined(GEMMLOWP_MSA) +#include "./fixedpoint_msa.h" #endif #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_ |