diff options
Diffstat (limited to 'ruy/mul_params.h')
-rw-r--r-- | ruy/mul_params.h | 62 |
1 files changed, 33 insertions, 29 deletions
diff --git a/ruy/mul_params.h b/ruy/mul_params.h index d5aa27b..42a5700 100644 --- a/ruy/mul_params.h +++ b/ruy/mul_params.h @@ -103,14 +103,9 @@ class MulParams final { // The bias vector data, if not null. const AccumScalar* bias() const { return storage_.bias; } void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; } - // Only for non-floating-point cases. The fixed-point part of the multiplier - // by which accumulators are multiplied before being casted to the destination - // type. This is a fixed-point quantity with 0 integer bits. Since - // (as explained in the class comment) AccumScalar must be std::int32_t, - // that means that the fixed-point format is Q0.31. For example, - // a multiplier_fixedpoint value of 2^30 has the effect of multiplying - // by one half (1/2). More generally, the effect is to multiply by - // (multiplier_fixedpoint / (2^31)). + // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa) + // of the multiplier by which accumulators are multiplied before being casted + // to the destination type. AccumScalar multiplier_fixedpoint() const { return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint; } @@ -132,10 +127,9 @@ class MulParams final { // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel` // and `multiplier_exponent_perchannel` are used instead. // - // This must point to a buffer of as many values as there are rows or columns - // in the destination matrix, whichever is the channels dimension. Each - // channel of the destination matrix will use the corresponding buffer element - // instead of multiplier_fixedpoint. + // This must point to a buffer of as many values as there are rows in the + // destination matrix. Each row of the destination matrix will use the + // corresponding buffer element instead of multiplier_fixedpoint. const AccumScalar* multiplier_fixedpoint_perchannel() const { return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel : nullptr; @@ -205,6 +199,16 @@ class MulParams final { detail::MulParamsStorage<AccumScalar, DstScalar> storage_; void set_perchannel(bool perchannel) { + if (storage_.perchannel == perchannel) { + return; + } + if (perchannel) { + RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0); + RUY_DCHECK_EQ(storage_.multiplier_exponent, 0); + } else { + RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr); + RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr); + } storage_.perchannel = perchannel; } }; @@ -240,25 +244,25 @@ template <typename DstScalar> struct MulParamsStorage<std::int32_t, DstScalar> final { using AccumScalar = std::int32_t; static_assert(std::is_integral<DstScalar>::value, ""); - static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, ""); + static_assert(sizeof(DstScalar) < sizeof(AccumScalar), ""); const AccumScalar* bias = nullptr; - union { - const AccumScalar* multiplier_fixedpoint_perchannel; - // Let the default multiplier be effecively a multiplication by 1, so that - // the matmul behaves as a (saturating) plain integer matmul. Unfortunately - // 1 is not exactly representable in fixedpoint with 0 integer bits, but - // using the highest representable value is a sufficiently good - // approximation: since this specialization of MulParams is for the case - // where DstScalar is at least 2x narrower than MulScalar, the values - // for which there would be a difference will get saturated anyway. - AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max(); - }; - union { - const int* multiplier_exponent_perchannel; - // See the above comment about the default value of multiplier_fixedpoint. - int multiplier_exponent = 0; - }; + // union { // This used to be a union, temporarily flattened to debug a crash + const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; + // Let the default multiplier be effecively a multiplication by 1, so that + // the matmul behaves as a (saturating) plain integer matmul. Unfortunately + // 1 is not exactly representable in fixedpoint with 0 integer bits, but + // using the highest representable value is a sufficiently good + // approximation: since this specialization of MulParams is for the case + // where DstScalar is at least 2x narrower than MulScalar, the values + // for which there would be a difference will get saturated anyway. + AccumScalar multiplier_fixedpoint = 0; + //}; + // union { // This used to be a union, temporarily flattened to debug a crash + const int* multiplier_exponent_perchannel = nullptr; + // See the above comment about the default value of multiplier_fixedpoint. + int multiplier_exponent = 0; + // }; DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest(); DstScalar clamp_max = std::numeric_limits<DstScalar>::max(); ChannelDimension channel_dimension = ChannelDimension::kRow; |