summaryrefslogtreecommitdiff
path: root/native/lang_id/common/embedding-network.cc
blob: 49c9ca024e6cb8bce70a9d9c7b633109b02edff7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/common/embedding-network.h"

#include <vector>

#include "lang_id/common/lite_base/integral-types.h"
#include "lang_id/common/lite_base/logging.h"

namespace libtextclassifier3 {
namespace mobile {
namespace {

// Verifies, via SAFTM_CHECK_EQ, that |matrix| is not quantized, i.e., that
// matrix.quant_type is QuantizationType::NONE.  The EmbeddingNetwork
// constructor calls this on every bias vector; SparseReluProductPlusBias later
// reads the bias elements directly as floats.
void CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
  // Both enum values are cast to int before being compared.
  SAFTM_CHECK_EQ(static_cast<int>(QuantizationType::NONE),
                 static_cast<int>(matrix.quant_type))
      << "Quantization not allowed here";
}

// Returns the number of bytes occupied by one row of |matrix|, computed from
// its number of columns and its quantization type:
//
//   NONE:    cols * sizeof(float)
//   UINT8:   cols * sizeof(uint8)
//   UINT4:   cols / 2  (two 4-bit values per byte; cols is DCHECK-ed to be
//            even)
//   FLOAT16: cols * sizeof(float16)
//
// Any other quantization type is reported with SAFTM_LOG(FATAL).
int GetMatrixRowSizeInBytes(const EmbeddingNetworkParams::Matrix &matrix) {
  int cols = matrix.cols;
  QuantizationType quant_type = matrix.quant_type;
  switch (quant_type) {
    case QuantizationType::NONE:
      return cols * sizeof(float);
    case QuantizationType::UINT8:
      return cols * sizeof(uint8);
    case QuantizationType::UINT4:
      SAFTM_DCHECK_EQ(cols % 2, 0) << "UINT4 with odd #cols = " << cols;
      return cols / 2;
    case QuantizationType::FLOAT16:
      return cols * sizeof(float16);
    default:
      SAFTM_LOG(FATAL) << "Unknown quant type: "
                       << static_cast<int>(quant_type);
      break;
  }

  // Only reachable for an unknown quantization type, after the FATAL log
  // above.  The explicit return keeps every control path of this non-void
  // function well-defined instead of flowing off its end.
  return 0;
}

// Computes *y = b + sum_i f(x[i]) * weights[i], where weights[i] is the i-th
// row of |weights| and f is Relu if |apply_relu| is true, identity otherwise.
//
// |weights| and |b| are the weight matrix and the bias vector of one neural
// network layer.  |weights| holds the transposed version of the real weight
// matrix: its rows are the columns of the original matrix, which is why the
// result is a linear combination of rows.
//
// In the research literature, Relu (the activation function) is usually the
// last part of a layer.  Seen that way, this function first finishes the
// previous layer (by applying its Relu, if requested) and then computes the
// pre-activation state of the current layer.
//
// Only QuantizationType::NONE and QuantizationType::FLOAT16 are supported for
// |weights|; any other type is reported with SAFTM_LOG(FATAL).  |b| is read as
// plain floats.
void SparseReluProductPlusBias(bool apply_relu,
                               const EmbeddingNetworkParams::Matrix &weights,
                               const EmbeddingNetworkParams::Matrix &b,
                               const std::vector<float> &x,
                               std::vector<float> *y) {
  // Start from the bias: *y = b.  b is a column matrix (i.e., b.cols == 1); the
  // EmbeddingNetwork constructor already CHECK-ed that.
  SAFTM_DCHECK_EQ(b.cols, 1);
  const float *const bias_begin = reinterpret_cast<const float *>(b.elements);
  y->assign(bias_begin, bias_begin + b.rows);

  // The shape of |weights| must match: one row per element of x, one column
  // per element of y.
  const int y_size = y->size();
  SAFTM_CHECK_EQ(weights.cols, y_size);
  const int x_size = x.size();
  SAFTM_CHECK_EQ(weights.rows, x_size);
  float *const out = y->data();

  // NOTE: the loops below read x_size * y_size elements from weights.  These
  // reads are safe as long as weights.elements really contains weights.rows *
  // weights.cols elements (of a size that depends on the quantization type).
  // That requirement is checked by the params provider, e.g., by
  // EmbeddingNetworkParamsFromFlatbuffer.

  // The two supported cases are intentionally written out twice: the switch is
  // kept outside the loops so that no per-element test on the quantization
  // type happens at runtime.
  switch (weights.quant_type) {
    case QuantizationType::NONE: {
      // Elements of |weights| are stored in row-major order (all of row #0,
      // then all of row #1, etc.).  At the top of iteration i, |row| points to
      // the first element of the i-th row; the loop increment moves it by
      // y_size == weights.cols (see the CHECK_EQ above) whether or not the row
      // was used.
      const float *row = reinterpret_cast<const float *>(weights.elements);
      for (int i = 0; i < x_size; ++i, row += y_size) {
        const float scale = x[i];
        if (apply_relu && !(scale > 0)) {
          // Relu maps this input to 0: the row contributes nothing to y.
          continue;
        }
        for (int j = 0; j < y_size; ++j) {
          out[j] += row[j] * scale;
        }
      }
      break;
    }
    case QuantizationType::FLOAT16: {
      // Same algorithm as for QuantizationType::NONE, except that each weight
      // is a float16 converted with Float16To32.  The two cases could share a
      // template, but this is a critical loop and we don't want to risk that,
      // e.g., the conversion function does not get inlined.
      const float16 *row = reinterpret_cast<const float16 *>(weights.elements);
      for (int i = 0; i < x_size; ++i, row += y_size) {
        const float scale = x[i];
        if (apply_relu && !(scale > 0)) {
          continue;
        }
        for (int j = 0; j < y_size; ++j) {
          out[j] += Float16To32(row[j]) * scale;
        }
      }
      break;
    }
    default:
      SAFTM_LOG(FATAL) << "Unsupported weights quantization type: "
                       << static_cast<int>(weights.quant_type);
  }
}
}  // namespace

void EmbeddingNetwork::ConcatEmbeddings(
    const std::vector<FeatureVector> &feature_vectors,
    std::vector<float> *concat) const {
  concat->resize(concat_layer_size_);

  // "es_index" stands for "embedding space index".
  for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
    const int concat_offset = concat_offset_[es_index];

    const EmbeddingNetworkParams::Matrix &embedding_matrix =
        embedding_matrices_[es_index];
    const int embedding_dim = embedding_matrix.cols;
    const int embedding_row_size_in_bytes =
        embedding_row_size_in_bytes_[es_index];

    const FeatureVector &feature_vector = feature_vectors[es_index];
    const int num_features = feature_vector.size();
    for (int fi = 0; fi < num_features; ++fi) {
      const FeatureType *feature_type = feature_vector.type(fi);
      int feature_offset = concat_offset + feature_type->base() * embedding_dim;
      SAFTM_CHECK_LE(feature_offset + embedding_dim, concat->size());

      // Weighted embeddings will be added starting from this address.
      float *concat_ptr = concat->data() + feature_offset;

      // Multiplier for each embedding weight.  Includes feature weight (for
      // continuous features) and quantization scale (for quantized embeddings).
      float multiplier;
      int feature_id;
      const FeatureValue feature_value = feature_vector.value(fi);
      if (feature_type->is_continuous()) {
        // Continuous features (encoded as FloatFeatureValue).
        FloatFeatureValue float_feature_value(feature_value);
        feature_id = float_feature_value.id;
        multiplier = float_feature_value.weight;
      } else {
        // Discrete features: every present feature has implicit value 1.0.
        feature_id = feature_value;
        multiplier = 1.0;
      }

      SAFTM_CHECK_GE(feature_id, 0);
      SAFTM_CHECK_LT(feature_id, embedding_matrix.rows);

      // Pointer to float / uint8 weights for relevant embedding.
      const void *embedding_data =
          (reinterpret_cast<const char *>(embedding_matrix.elements) +
           feature_id * embedding_row_size_in_bytes);

      switch (embedding_matrix.quant_type) {
        case QuantizationType::NONE: {
          const float *weights =
              reinterpret_cast<const float *>(embedding_data);
          for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
            *concat_ptr += *weights * multiplier;
          }
          break;
        }
        case QuantizationType::UINT8: {
          multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
          const uint8 *quant_weights =
              reinterpret_cast<const uint8 *>(embedding_data);
          for (int i = 0; i < embedding_dim;
               ++i, ++quant_weights, ++concat_ptr) {
            // 128 is bias for UINT8 quantization.
            *concat_ptr +=
                (static_cast<int>(*quant_weights) - 128) * multiplier;
          }
          break;
        }
        case QuantizationType::UINT4: {
          multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
          const uint8 *quant_weights =
              reinterpret_cast<const uint8 *>(embedding_data);
          for (int i = 0; i < embedding_dim / 2; ++i, ++quant_weights) {
            const uint8 qq = *quant_weights;
            concat_ptr[0] +=
                (static_cast<int>((qq & 0xF0) | 0x08) - 128) * multiplier;
            concat_ptr[1] +=
                (static_cast<int>(((qq & 0x0F) << 4) | 0x08) - 128) *
                multiplier;
            concat_ptr += 2;
          }
          break;
        }
        default:
          // We already checked (in GetMatrixRowSizeInBytes) that each embedding
          // matrix has a known quantization type.  Hence, DLOG is enough here.
          SAFTM_DLOG(ERROR) << "Unknown embeddings quantization type "
                            << static_cast<int>(embedding_matrix.quant_type);
          break;
      }
    }
  }
}

void EmbeddingNetwork::ComputeFinalScores(
    const std::vector<FeatureVector> &features,
    std::vector<float> *scores) const {
  ComputeFinalScores(features, {}, scores);
}

// Runs the feed-forward neural network (FFNN) on |features| and stores the
// output of the last layer in |*scores|.
//
// The input layer is the concatenation of the embeddings for |features| (see
// ConcatEmbeddings), followed by the elements of |extra_inputs|, if any.  That
// input is then propagated through all layers from layer_weights_ /
// layer_bias_; Relu is applied to the input of every layer except the first.
void EmbeddingNetwork::ComputeFinalScores(
    const std::vector<FeatureVector> &features,
    const std::vector<float> &extra_inputs, std::vector<float> *scores) const {
  // Construct the input layer for our FFNN.
  std::vector<float> input;
  ConcatEmbeddings(features, &input);

  // Append the extra inputs (a no-op if there are none).
  input.insert(input.end(), extra_inputs.begin(), extra_inputs.end());

  // Propagate input through all layers of our FFNN.

  // Alternating storage for activations of the different layers.  We can't use
  // a single vector because all activations of the previous layer are required
  // when computing the activations of the next one.
  std::vector<float> storage[2];
  const std::vector<float> *v_in = &input;
  const int num_layers = layer_weights_.size();
  for (int i = 0; i < num_layers; ++i) {
    // The final layer writes its results directly into |scores|; a hidden layer
    // writes into the alternating storage (the i % 2 trick ensures the
    // alternation, so v_out never aliases v_in).
    const bool is_final_layer = (i == num_layers - 1);
    std::vector<float> *v_out = is_final_layer ? scores : &(storage[i % 2]);

    // The first layer consumes the raw input; every other layer consumes the
    // Relu of the previous layer's output.
    const bool apply_relu = i > 0;
    SparseReluProductPlusBias(
        apply_relu, layer_weights_[i], layer_bias_[i], *v_in, v_out);
    v_in = v_out;
  }
}

// Stores |model| in model_ and caches the pieces of it used when computing
// scores:
//   * for each embedding space: its embedding matrix, the size in bytes of one
//     matrix row, and the offset of the space inside the concatenation layer;
//   * the total size of the concatenation layer;
//   * the weights and bias of each hidden layer, followed by those of the
//     softmax layer.
//
// CHECK-fails if the number of hidden layers differs from the number of hidden
// biases, if the model has no softmax layer, or if a bias is quantized or does
// not have exactly one column.
EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
    : model_(model) {
  // Embedding spaces.  Space #i starts at next_offset in the concatenation
  // layer and occupies (embedding dimension) * (number of features) elements.
  int next_offset = 0;
  for (int i = 0; i < model_->embedding_num_features_size(); ++i) {
    // NOTE: each Matrix is a small struct that doesn't own the actual matrix
    // weights.  Hence, copying it around (here and in the push_back below) is
    // fast.
    const EmbeddingNetworkParams::Matrix embeddings =
        model_->GetEmbeddingMatrix(i);

    concat_offset_.push_back(next_offset);
    next_offset += embeddings.cols * model_->embedding_num_features(i);

    embedding_matrices_.push_back(embeddings);
    embedding_row_size_in_bytes_.push_back(
        GetMatrixRowSizeInBytes(embeddings));
  }
  concat_layer_size_ = next_offset;

  // Hidden layers: one weight matrix and one unquantized, single-column bias
  // per layer.
  SAFTM_CHECK_EQ(model_->hidden_size(), model_->hidden_bias_size());
  for (int i = 0; i < model_->hidden_size(); ++i) {
    layer_weights_.push_back(model_->GetHiddenLayerMatrix(i));

    const EmbeddingNetworkParams::Matrix bias = model_->GetHiddenLayerBias(i);
    SAFTM_CHECK_EQ(1, bias.cols);
    CheckNoQuantization(bias);
    layer_bias_.push_back(bias);
  }

  // The softmax layer is mandatory and is stored as the last layer.
  SAFTM_CHECK(model_->HasSoftmax());
  layer_weights_.push_back(model_->GetSoftmaxMatrix());

  const EmbeddingNetworkParams::Matrix softmax_bias = model_->GetSoftmaxBias();
  SAFTM_CHECK_EQ(1, softmax_bias.cols);
  CheckNoQuantization(softmax_bias);
  layer_bias_.push_back(softmax_bias);
}

}  // namespace mobile
}  // namespace libtextclassifier3