aboutsummaryrefslogtreecommitdiff
path: root/ruy/mul_params.h
diff options
context:
space:
mode:
Diffstat (limited to 'ruy/mul_params.h')
-rw-r--r-- ruy/mul_params.h 62
1 files changed, 33 insertions, 29 deletions
diff --git a/ruy/mul_params.h b/ruy/mul_params.h
index d5aa27b..42a5700 100644
--- a/ruy/mul_params.h
+++ b/ruy/mul_params.h
@@ -103,14 +103,9 @@ class MulParams final {
// The bias vector data, if not null.
const AccumScalar* bias() const { return storage_.bias; }
void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
- // Only for non-floating-point cases. The fixed-point part of the multiplier
- // by which accumulators are multiplied before being casted to the destination
- // type. This is a fixed-point quantity with 0 integer bits. Since
- // (as explained in the class comment) AccumScalar must be std::int32_t,
- // that means that the fixed-point format is Q0.31. For example,
- // a multiplier_fixedpoint value of 2^30 has the effect of multiplying
- // by one half (1/2). More generally, the effect is to multiply by
- // (multiplier_fixedpoint / (2^31)).
+ // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
+ // of the multiplier by which accumulators are multiplied before being cast
+ // to the destination type.
AccumScalar multiplier_fixedpoint() const {
return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
}
@@ -132,10 +127,9 @@ class MulParams final {
// `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
// and `multiplier_exponent_perchannel` are used instead.
//
- // This must point to a buffer of as many values as there are rows or columns
- // in the destination matrix, whichever is the channels dimension. Each
- // channel of the destination matrix will use the corresponding buffer element
- // instead of multiplier_fixedpoint.
+ // This must point to a buffer of as many values as there are rows in the
+ // destination matrix. Each row of the destination matrix will use the
+ // corresponding buffer element instead of multiplier_fixedpoint.
const AccumScalar* multiplier_fixedpoint_perchannel() const {
return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
: nullptr;
@@ -205,6 +199,16 @@ class MulParams final {
detail::MulParamsStorage<AccumScalar, DstScalar> storage_;
void set_perchannel(bool perchannel) {
+ if (storage_.perchannel == perchannel) {
+ return;
+ }
+ if (perchannel) {
+ RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0);
+ RUY_DCHECK_EQ(storage_.multiplier_exponent, 0);
+ } else {
+ RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr);
+ RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr);
+ }
storage_.perchannel = perchannel;
}
};
@@ -240,25 +244,25 @@ template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
using AccumScalar = std::int32_t;
static_assert(std::is_integral<DstScalar>::value, "");
- static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, "");
+ static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");
const AccumScalar* bias = nullptr;
- union {
- const AccumScalar* multiplier_fixedpoint_perchannel;
- // Let the default multiplier be effecively a multiplication by 1, so that
- // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
- // 1 is not exactly representable in fixedpoint with 0 integer bits, but
- // using the highest representable value is a sufficiently good
- // approximation: since this specialization of MulParams is for the case
- // where DstScalar is at least 2x narrower than MulScalar, the values
- // for which there would be a difference will get saturated anyway.
- AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
- };
- union {
- const int* multiplier_exponent_perchannel;
- // See the above comment about the default value of multiplier_fixedpoint.
- int multiplier_exponent = 0;
- };
+ // union { // This used to be a union, temporarily flattened to debug a crash
+ const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
+ // Let the default multiplier be effectively a multiplication by 1, so that
+ // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
+ // 1 is not exactly representable in fixedpoint with 0 integer bits, but
+ // using the highest representable value is a sufficiently good
+ // approximation: since this specialization of MulParams is for the case
+ // where DstScalar is at least 2x narrower than MulScalar, the values
+ // for which there would be a difference will get saturated anyway.
+ AccumScalar multiplier_fixedpoint = 0;
+ //};
+ // union { // This used to be a union, temporarily flattened to debug a crash
+ const int* multiplier_exponent_perchannel = nullptr;
+ // See the above comment about the default value of multiplier_fixedpoint.
+ int multiplier_exponent = 0;
+ // };
DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
ChannelDimension channel_dimension = ChannelDimension::kRow;