// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint.h: fixed-point arithmetic, with basic operations and
// a few math functions such as tanh.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "../internal/detect_platform.h"

namespace gemmlowp {

// Part 1: Low-level integer-arithmetic primitives.
// The implementations here are generic implementations valid for
// scalar types (e.g. std::int32_t). Architecture-specific SIMD types
// (e.g. NEON int32x4_t) may be supported by providing
// specializations for them in separate files.
//
// The purpose of these primitives is two-fold:
//  - They will be used to implement higher-level fixed-point
//    abstractions, namely the FixedPoint class and its arithmetic
//    operators.
//  - They will be directly used to implement some more involved
//    fixed-point computations, e.g. the fixed-point implementation
//    of math functions such as tanh.

// Some compile-time traits around raw types to handle SIMD aspects:
// number of lanes, underlying scalar type.
// Primary template: intentionally empty. Every supported raw type (scalar or
// SIMD) must provide a specialization exposing ScalarRawType and kLanes.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Scalar 32-bit raw type: a single lane of std::int32_t.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  typedef std::int32_t ScalarRawType;
  static constexpr int kLanes = 1;
};

// Scalar 16-bit raw type: a single lane of std::int16_t.
template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  typedef std::int16_t ScalarRawType;
  static constexpr int kLanes = 1;
};

// Returns a SIMD value duplicating a scalar value across all lanes.
// This generic version handles scalar raw types, where it is the identity.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
  return x;
}

// Plain bit-wise AND
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  return a & b;
}

// Plain bit-wise OR
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  return a | b;
}

// Plain bit-wise XOR
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  return a ^ b;
}

// Plain bit-wise NOT
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  return ~a;
}

// Integer addition. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  return a + b;
}

// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return a * b;
}

// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  return a - b;
}

// Integer unary negative. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  return -a;
}

// Integer arithmetic left-shift, equivalent to multiplying with a power of two.
// Negative values are OK. In case of overflow, no Undefined
// Behavior, but the results are implementation-defined (in practice,
// they currently are saturated, but we make no commitment to that). The idea
// is that the caller will want to implement the overflowing cases with
// saturation with compare-and-mask, so we don't care about the results
// in the overflow case, we just want to avoid undefined behavior.
//
// tIntegerType may be int32 or any narrower signed type.
template <typename tIntegerType, typename OffsetType>
tIntegerType ShiftLeft(tIntegerType a, OffsetType offset) {
  const std::int64_t wide_a = static_cast<std::int64_t>(a);
  // Build the power of two in 64 bits: `1 << offset` evaluated as a 32-bit
  // int overflows for offset == 31, which would flip the sign of the product
  // and saturate in the wrong direction.
  const std::int64_t wide_shifted =
      wide_a * (static_cast<std::int64_t>(1) << offset);
  const auto min = std::numeric_limits<tIntegerType>::min();
  const auto max = std::numeric_limits<tIntegerType>::max();
  // Clamp the exact 64-bit product back into the range of tIntegerType.
  return wide_shifted < min
             ? min
             : wide_shifted > max ? max
                                  : static_cast<tIntegerType>(wide_shifted);
}

// Integer arithmetic right-shift. Not rounding.
// Relying on implementation-defined, but in-practice-consistent,
// C++ compiler behavior.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  return a >> offset;
}

// Each bit of the result is set to the corresponding bit of either then_val or
// else_val depending on whether the corresponding bit of if_mask is set.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  // The two AND terms never share a set bit, so XOR acts as OR here.
  return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is non-zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  static constexpr tIntegerType zero = 0;
  return a ? BitNot(zero) : zero;
}

// For each input scalar, the corresponding bits of the result are set if the
// input scalar is zero.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
  return MaskIfNonZero<tIntegerType>(!a);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are equal.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  return MaskIfNonZero<tIntegerType>(a == b);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars are not equal.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
  return MaskIfNonZero<tIntegerType>(a != b);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a > b.
template tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) { return MaskIfNonZero(a > b); } // For each pair of input scalars, the corresponding bits of the result are // set if the input scalars a, b satisfy a >= b. template tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { return MaskIfNonZero(a >= b); } // For each pair of input scalars, the corresponding bits of the result are // set if the input scalars a, b satisfy a < b. template tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { return MaskIfNonZero(a < b); } // For each pair of input scalars, the corresponding bits of the result are // set if the input scalars a, b satisfy a <= b. template tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { return MaskIfNonZero(a <= b); } // Returns true if all of the input scalars are nonzero. // This function may currently assume that each of the input scalars has either // all or none of its bits set. Otherwise, its behavior is currently undefined. template bool All(tIntegerType a) { return a; } // Returns true if any of the input scalars are nonzero. // This function may currently assume that each of the input scalars has either // all or none of its bits set. Otherwise, its behavior is currently undefined. template bool Any(tIntegerType a) { return a; } // Returns (a+b)/2, rounded to the nearest integer. // Equivalent to VRHADD in the ARM NEON instruction set. template IntegerType RoundingHalfSum(IntegerType a, IntegerType b) { static_assert(std::is_same::value, "unimplemented"); (void)b; return a; } template <> inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) { std::int64_t a64 = a; std::int64_t b64 = b; std::int64_t sum = a64 + b64; std::int64_t sign = sum >= 0 ? 
1 : -1; return static_cast((sum + sign) / 2); } template <> inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) { std::int32_t a32 = a; std::int32_t b32 = b; std::int32_t sum = a32 + b32; std::int32_t sign = sum >= 0 ? 1 : -1; return static_cast((sum + sign) / 2); } template IntegerType SaturatingAdd(IntegerType a, IntegerType b) { static_assert(std::is_same::value, "unimplemented"); (void)b; return a; } // So far this is only needed for int16. template <> inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) { std::int32_t a32 = a; std::int32_t b32 = b; std::int32_t sum = a32 + b32; return static_cast( std::min(static_cast(32767), std::max(static_cast(-32768), sum))); } template <> inline std::int8_t SaturatingAdd(std::int8_t a, std::int8_t b) { std::int16_t a16 = a; std::int16_t b16 = b; std::int16_t sum = a16 + b16; return static_cast(std::min( static_cast(std::numeric_limits::max()), std::max(static_cast(std::numeric_limits::min()), sum))); } // Returns a+b, saturating if the integers are 16bit or narrower, // otherwise just a plain addition. template struct AddSaturatingIf16BitImpl { static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); } }; template struct AddSaturatingIf16BitImpl { static IntegerType Run(IntegerType a, IntegerType b) { return SaturatingAdd(a, b); } }; template IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) { using ScalarType = typename FixedPointRawTypeTraits::ScalarRawType; return AddSaturatingIf16BitImpl::Run(a, b); } // Returns the integer that represents the product of two fixed-point // numbers, interpreting all integers as fixed-point values in the // interval [-1, 1), rounding to the nearest value, and saturating // -1 * -1 to the maximum value (since 1 is not in the half-open // interval [-1, 1)). // // [The explanation below specializes to std::int32_t for example purpose.] 
// // The mapping between IntegerType and the interval [-1, 1) is unique and // implied by IntegerType, which is assumed to be signed. For example, // for IntegerType==std::int32_t, the mapping is // real_value = integer_value / 2^31. // So in this case, and leaving aside rounding and saturating, this // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to // (a * b) / 2^31. // // The 'doubling' part in the name of this function comes from the fact that // this operation is very close to a "multiply-high" operation, keeping only // the top half bits, except that that would be effectively computing // (a * b) / 2^32, // so here we are computing 2x that, since // 1/2^31 = 2 * 1/2^32. // The idea is to use all of the available 32 bits in the destination int32 // value. // // [End of the explanation specializing to int32.] // // This is equivalent to the VQRDMULH instruction in ARM NEON. template IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { static_assert(std::is_same::value, "unimplemented"); (void)b; return a; } // This function implements the same computation as the ARMv7 NEON VQRDMULH // instruction. template <> inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) { bool overflow = a == b && a == std::numeric_limits::min(); std::int64_t a_64(a); std::int64_t b_64(b); std::int64_t ab_64 = a_64 * b_64; std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); std::int32_t ab_x2_high32 = static_cast((ab_64 + nudge) / (1ll << 31)); return overflow ? std::numeric_limits::max() : ab_x2_high32; } template <> inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a, std::int16_t b) { bool overflow = a == b && a == std::numeric_limits::min(); std::int32_t a_32(a); std::int32_t b_32(b); std::int32_t ab_32 = a_32 * b_32; std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14)); std::int16_t ab_x2_high16 = static_cast((ab_32 + nudge) / (1 << 15)); return overflow ? 
std::numeric_limits::max() : ab_x2_high16; } // Correctly-rounded-to-nearest division by a power-of-two. // Also known as a rounding arithmetic right shift. template inline IntegerType RoundingDivideByPOT(IntegerType x, ExponentType exponent) { assert(exponent >= 0); assert(exponent <= 31); const IntegerType mask = Dup((1ll << exponent) - 1); const IntegerType zero = Dup(0); const IntegerType one = Dup(1); const IntegerType remainder = BitAnd(x, mask); const IntegerType threshold = Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); return Add(ShiftRight(x, exponent), BitAnd(MaskIfGreaterThan(remainder, threshold), one)); } // Returns the product of a run-time integer value by a compile-time power // of two, with either a positive exponent (equivalent to an arithmetic // left shift, saturating) or a negative exponent (equivalent to an arithmetic // right shift, rounding to nearest). template 0 ? 1 : Exponent < 0 ? -1 : 0)> struct ImplSaturatingRoundingMultiplyByPOT {}; template struct ImplSaturatingRoundingMultiplyByPOT { static IntegerType eval(IntegerType x) { return x; } }; template struct ImplSaturatingRoundingMultiplyByPOT { static IntegerType eval(IntegerType x) { using ScalarIntegerType = typename FixedPointRawTypeTraits::ScalarRawType; const IntegerType min = Dup(std::numeric_limits::min()); const IntegerType max = Dup(std::numeric_limits::max()); const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); const std::int32_t threshold = ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1); const IntegerType positive_mask = MaskIfGreaterThan(x, Dup(threshold)); const IntegerType negative_mask = MaskIfLessThan(x, Dup(-threshold)); IntegerType result = ShiftLeft(x, Exponent); result = SelectUsingMask(positive_mask, max, result); result = SelectUsingMask(negative_mask, min, result); return result; } }; template struct ImplSaturatingRoundingMultiplyByPOT { static IntegerType eval(IntegerType x) { return RoundingDivideByPOT(x, -Exponent); } }; 
template IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { return ImplSaturatingRoundingMultiplyByPOT::eval(x); } // Part 2: the FixedPoint class. // A FixedPoint object represents a fixed-point value stored in the underlying // integer type tRawType, if tRawType is a plain scalar integer type. // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which // case a FixedPoint object represents a corresponding SIMD vector of fixed // point values. // // tIntegerBits describes the range of the fixed-point format: if // tIntegerBits == m then the range of representable values is the half-open // interval [-2^m; 2^m) where the open boundary on the right side means that // 2^m is not representable (how close the maximum representable value is to // it, depends on bit-depth of tRawType). // // In "Q format notation", // https://en.wikipedia.org/wiki/Q_(number_format) // we are describing the format // Qm.n // where // m = tIntegerBits // and // n = NumberOfBits(tRawType) - (m + 1) // Note that the (m + 1) in the above line is because we adopt the convention // that we count the integer bits exclusively of the sign bit; so (m + 1) is // the total number of integer bits inclusive of the sign bit. // // Accordingly, the number of integral representable values in our range // [-2^m ; 2^m) // is equal to 2^(m+1). 
template class FixedPoint { public: typedef tRawType RawType; typedef FixedPointRawTypeTraits RawTypeTraits; typedef typename RawTypeTraits::ScalarRawType ScalarRawType; static constexpr int kTotalBits = 8 * sizeof(ScalarRawType); static constexpr int kIntegerBits = tIntegerBits; static constexpr int kFractionalBits = kTotalBits - 1 - kIntegerBits; static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, "bad IntegerBits"); typedef FixedPoint ScalarFixedPointType; static const ScalarRawType ScalarRawMin() { return std::numeric_limits::min(); } static const ScalarRawType ScalarRawMax() { return std::numeric_limits::max(); } static const ScalarRawType RawMin() { return VectorFromScalar(ScalarRawMin()); } static const ScalarRawType RawMax() { return VectorFromScalar(ScalarRawMax()); } static FixedPoint FromRaw(RawType x) { FixedPoint retval; retval.raw() = x; return retval; } static FixedPoint FromScalarRaw(ScalarRawType x) { FixedPoint retval; retval.raw() = Dup(x); return retval; } static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { return FromScalarRaw(x.raw()); } template static FixedPoint ConstantPOT() { static constexpr int kOffset = kFractionalBits + Exponent; static_assert( kOffset < 31, "Constant not exactly representable in this fixed-point format"); return FromScalarRaw(ScalarRawType(1) << kOffset); } static FixedPoint Zero() { return FromScalarRaw(0); } static FixedPoint One() { return FromScalarRaw( kIntegerBits == 0 ? ScalarRawMax() : (ScalarRawType(1) << (kIntegerBits == 0 ? 
0 : kFractionalBits))); } static FixedPoint FromDouble(double x) { const double min_bound = static_cast(ScalarRawMin()); const double max_bound = static_cast(ScalarRawMax()); return FromScalarRaw(static_cast(std::min( std::max(round(x * static_cast(1ll << kFractionalBits)), min_bound), max_bound))); } RawType raw() const { return i_; } RawType& raw() { return i_; } private: RawType i_; }; // Part 3: implementation of arithmetic operators for the // FixedPoint class, and a few related functions. // A FixedPoint multiplication is just a // SaturatingRoundingDoublingHighMul operation on the underlying // raw integer values. The IntegerBits simply add up, as is obvious // from the fact that the range is [-2^IntegerBits, 2^IntegerBits). template FixedPoint operator*( FixedPoint a, FixedPoint b) { FixedPoint c; c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); return c; } // Tweaking IntegerBits gives exact multiplication by a power of two. template FixedPoint ExactMulByPot( FixedPoint a) { FixedPoint c; c.raw() = a.raw(); return c; } // If we want to leave IntegerBits fixed, then multiplication // by a power of two has to be saturating/rounding, not exact anymore. template FixedPoint SaturatingRoundingMultiplyByPOT( FixedPoint a) { return FixedPoint::FromRaw( SaturatingRoundingMultiplyByPOT(a.raw())); } // Generic arithmetic operators. 
#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ template \ FixedPoint FuncName( \ FixedPoint a) { \ return FixedPoint::FromRaw(ImplFuncName(a.raw())); \ } #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ template \ FixedPoint FuncName( \ FixedPoint a, \ FixedPoint b) { \ return FixedPoint::FromRaw( \ ImplFuncName(a.raw(), b.raw())); \ } MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) #undef MAKE_FIXEDPOINT_UNARY_FUNC #undef MAKE_FIXEDPOINT_BINARY_FUNC #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ template \ tRawType FuncName(FixedPoint a) { \ return FuncName(a.raw()); \ } #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ template \ tRawType FuncName(FixedPoint a, \ FixedPoint b) { \ return FuncName(a.raw(), b.raw()); \ } MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW template FixedPoint SelectUsingMask( tRawType if_mask, FixedPoint then_val, FixedPoint else_val) { return FixedPoint::FromRaw( SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); } template bool operator==(FixedPoint a, FixedPoint b) { return All(MaskIfEqual(a.raw(), b.raw())); } template bool 
operator!=(FixedPoint a, FixedPoint b) { return !(a == b); } template FixedPoint SaturatingAdd( FixedPoint a, FixedPoint b) { return FixedPoint::FromRaw( SaturatingAdd(a.raw(), b.raw())); } template FixedPoint AddSaturatingIf16Bit( FixedPoint a, FixedPoint b) { return FixedPoint::FromRaw( AddSaturatingIf16Bit(a.raw(), b.raw())); } // Conversion to floating-point. template double ToDouble(FixedPoint x) { static_assert(FixedPointRawTypeTraits::kLanes == 1, "not applicable to SIMD types"); typedef FixedPoint F; return x.raw() / static_cast(1ll << F::kFractionalBits); } // Rescale changes the number of IntegerBits and updates the underlying // raw integer value accordingly. template FixedPoint Rescale( FixedPoint x) { static constexpr int kExponent = tIntegerBitsSrc - tIntegerBitsDst; FixedPoint result; result.raw() = SaturatingRoundingMultiplyByPOT(x.raw()); return result; } // CheckedFixedPointConstant allows to specify fixed-point constants // initialized as real numbers, in a way that does not compile floating-point // arithmetic in production code, yet still checks agreement with the // floating-point expressions when asserts are enabled. // // The raw integer value provided is always a int32, encoding a 32-bit // fixed-point value, regardless of the actual Scalar type. This allows // writing generic code that applies just as well to the 32-bit and 16-bit // cases. In the 16-bit case, the raw integer value is internally // rounding-shifted by 16 bits to the right. 
template inline typename FixedPointType::ScalarRawType RescaleConstantInitializer( std::int32_t int32_value) { typedef typename FixedPointType::ScalarRawType ScalarRawType; static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType); return static_cast( RoundingDivideByPOT(int32_value, 32 - ScalarTypeBits)); } #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS template FixedPointType CheckedFixedPointConstant(std::int32_t raw_value, double double_value) { const FixedPointType result = FixedPointType::FromScalarRaw(raw_value); assert(result == FixedPointType::FromDouble(double_value)); return result; } #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ ScalarRawInt32Value, DoubleValue) \ (gemmlowp::CheckedFixedPointConstant( \ gemmlowp::RescaleConstantInitializer( \ ScalarRawInt32Value), \ DoubleValue)) #else #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, \ ScalarRawInt32Value, DoubleValue) \ (FixedPointType::FromScalarRaw( \ gemmlowp::RescaleConstantInitializer( \ ScalarRawInt32Value))) #endif // Implementation of exponential function. // Returns exp(x) for x in [-1/4, 0). template FixedPoint exp_on_interval_between_negative_one_quarter_and_0_excl( FixedPoint a) { typedef FixedPoint F; const F constant_term = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); const F constant_1_over_3 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); // We're evaluating a Taylor expansion around -1/8, so we do the change of // variable: x = a + 1/8. // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. 
F x = a + F::template ConstantPOT<-3>(); F x2 = x * x; F x3 = x2 * x; F x4 = x2 * x2; F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4); F x4_over_24_plus_x3_over_6_plus_x2_over_2 = SaturatingRoundingMultiplyByPOT<-1>( ((x4_over_4 + x3) * constant_1_over_3) + x2); return AddSaturatingIf16Bit( constant_term, constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); } // Returns exp(x) for x < 0. template FixedPoint exp_on_negative_values( FixedPoint a) { typedef FixedPoint InputF; typedef FixedPoint ResultF; static constexpr int kFractionalBits = InputF::kFractionalBits; static constexpr int kIntegerBits = InputF::kIntegerBits; const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); InputF mask = kOneQuarter - InputF::FromScalarRaw(1); InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( Rescale<0>(a_mod_quarter_minus_one_quarter)); tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw(); #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ if (kIntegerBits > Exponent) { \ const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( \ ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \ static constexpr int kShiftAmount = \ kIntegerBits > Exponent ? 
kFractionalBits + Exponent : 0; \ result = SelectUsingMask( \ MaskIfNonZero(BitAnd(remainder, Dup(1 << kShiftAmount))), \ result * kMultiplier, result); \ } // Constants below are Q0 representations of negative exp fractionals: GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); // exp(-1/4) GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); // exp(-1/2) GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); // exp(-1) GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); // exp(-2) GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); // exp(-4) GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); // exp(-8) GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); // exp(-16) #undef GEMMLOWP_EXP_BARREL_SHIFTER static constexpr int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0; if (kIntegerBits > 5) { const InputF clamp = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << clampB), -32.0); result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); } result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result); return result; } // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)). // Returns (1 - x) / (1 + x) for x in (0, 1). template FixedPoint one_minus_x_over_one_plus_x_for_x_in_0_1( FixedPoint a) { typedef FixedPoint F0; typedef FixedPoint F2; F0 half_denominator = RoundingHalfSum(a, F0::One()); // Newton-Raphson division // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division // Refer to that page for the logic behind the 48/17 and 32/17 constants. 
const F2 constant_48_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); const F2 constant_neg_32_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; for (int i = 0; i < 3; i++) { F2 half_denominator_times_x = half_denominator * x; F2 one_minus_half_denominator_times_x = F2::One() - half_denominator_times_x; x = x + Rescale<2>(x * one_minus_half_denominator_times_x); } return Rescale<0>(x - F2::One()); } // Returns -tanh(x) for x < 0. template FixedPoint neg_tanh_on_negative_values( FixedPoint a) { return one_minus_x_over_one_plus_x_for_x_in_0_1( exp_on_negative_values(ExactMulByPot<1>(a))); } // Returns tanh(x) for any x. template FixedPoint tanh(FixedPoint a) { typedef FixedPoint InputF; typedef FixedPoint ResultF; tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); tRawType mask_if_zero = MaskIfZero(a); InputF n = SelectUsingMask(mask_if_negative, a, -a); ResultF t = neg_tanh_on_negative_values(n); return SelectUsingMask(mask_if_zero, ResultF::Zero(), SelectUsingMask(mask_if_negative, -t, t)); } // Implementation of logistic function. // Returns 1 / (1 + x) for x in (0, 1). template FixedPoint one_over_one_plus_x_for_x_in_0_1( FixedPoint a) { typedef FixedPoint F0; typedef FixedPoint F2; F0 half_denominator = RoundingHalfSum(a, F0::One()); // Newton-Raphson division // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division // Refer to that page for the logic behind the 48/17 and 32/17 constants. 
const F2 constant_48_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); const F2 constant_neg_32_over_17 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; for (int i = 0; i < 3; i++) { F2 half_denominator_times_x = half_denominator * x; F2 one_minus_half_denominator_times_x = F2::One() - half_denominator_times_x; x = x + Rescale<2>(x * one_minus_half_denominator_times_x); } return Rescale<0>(ExactMulByPot<-1>(x)); } // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. template FixedPoint logistic_on_positive_values( FixedPoint a) { return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); } // Returns logistic(x) = 1 / (1 + exp(-x)) for any x. template FixedPoint logistic(FixedPoint a) { typedef FixedPoint InputF; typedef FixedPoint ResultF; tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero()); tRawType mask_if_zero = MaskIfZero(a); InputF abs_input = SelectUsingMask(mask_if_positive, a, -a); ResultF result_if_positive = logistic_on_positive_values(abs_input); ResultF result_if_negative = ResultF::One() - result_if_positive; const ResultF one_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5); return SelectUsingMask(mask_if_zero, one_half, SelectUsingMask(mask_if_positive, result_if_positive, result_if_negative)); } } // end namespace gemmlowp #ifdef GEMMLOWP_NEON #include "./fixedpoint_neon.h" #elif defined(GEMMLOWP_AVX2) #include "./fixedpoint_avx.h" #elif defined(GEMMLOWP_SSE4) #include "./fixedpoint_sse.h" #elif defined(GEMMLOWP_MSA) #include "./fixedpoint_msa.h" #elif defined(GEMMLOWP_WASMSIMD) #include "./fixedpoint_wasmsimd.h" #endif #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_