summaryrefslogtreecommitdiff
path: root/annotator/model-executor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'annotator/model-executor.cc')
-rw-r--r--  annotator/model-executor.cc  124
1 file changed, 124 insertions, 0 deletions
diff --git a/annotator/model-executor.cc b/annotator/model-executor.cc
new file mode 100644
index 0000000..7c57e8f
--- /dev/null
+++ b/annotator/model-executor.cc
@@ -0,0 +1,124 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/model-executor.h"
+
+#include "annotator/quantization.h"
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+TensorView<float> ModelExecutor::ComputeLogits(
+ const TensorView<float>& features, tflite::Interpreter* interpreter) const {
+ if (!interpreter) {
+ return TensorView<float>::Invalid();
+ }
+ interpreter->ResizeInputTensor(kInputIndexFeatures, features.shape());
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ TC3_VLOG(1) << "Allocation failed.";
+ return TensorView<float>::Invalid();
+ }
+
+ SetInput<float>(kInputIndexFeatures, features, interpreter);
+
+ if (interpreter->Invoke() != kTfLiteOk) {
+ TC3_VLOG(1) << "Interpreter failed.";
+ return TensorView<float>::Invalid();
+ }
+
+ return OutputView<float>(kOutputIndexLogits, interpreter);
+}
+
+std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::FromBuffer(
+ const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
+ int quantization_bits) {
+ std::unique_ptr<TfLiteModelExecutor> executor =
+ TfLiteModelExecutor::FromBuffer(model_spec_buffer);
+ if (!executor) {
+ TC3_LOG(ERROR) << "Could not load TFLite model for embeddings.";
+ return nullptr;
+ }
+
+ std::unique_ptr<tflite::Interpreter> interpreter =
+ executor->CreateInterpreter();
+ if (!interpreter) {
+ TC3_LOG(ERROR) << "Could not build TFLite interpreter for embeddings.";
+ return nullptr;
+ }
+
+ if (interpreter->tensors_size() != 2) {
+ return nullptr;
+ }
+ const TfLiteTensor* embeddings = interpreter->tensor(0);
+ if (embeddings->dims->size != 2) {
+ return nullptr;
+ }
+ int num_buckets = embeddings->dims->data[0];
+ const TfLiteTensor* scales = interpreter->tensor(1);
+ if (scales->dims->size != 2 || scales->dims->data[0] != num_buckets ||
+ scales->dims->data[1] != 1) {
+ return nullptr;
+ }
+ int bytes_per_embedding = embeddings->dims->data[1];
+ if (!CheckQuantizationParams(bytes_per_embedding, quantization_bits,
+ embedding_size)) {
+ TC3_LOG(ERROR) << "Mismatch in quantization parameters.";
+ return nullptr;
+ }
+
+ return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor(
+ std::move(executor), quantization_bits, num_buckets, bytes_per_embedding,
+ embedding_size, scales, embeddings, std::move(interpreter)));
+}
+
// Private constructor used by FromBuffer() after all validation has passed.
// Takes ownership of the model executor and interpreter; `scales` and
// `embeddings` are non-owning pointers into tensors owned by `interpreter`,
// so they remain valid for the lifetime of this object (the interpreter is
// stored as a member).
TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor(
    std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
    int num_buckets, int bytes_per_embedding, int output_embedding_size,
    const TfLiteTensor* scales, const TfLiteTensor* embeddings,
    std::unique_ptr<tflite::Interpreter> interpreter)
    : executor_(std::move(executor)),
      quantization_bits_(quantization_bits),
      num_buckets_(num_buckets),
      bytes_per_embedding_(bytes_per_embedding),
      output_embedding_size_(output_embedding_size),
      scales_(scales),
      embeddings_(embeddings),
      interpreter_(std::move(interpreter)) {}
+
+bool TFLiteEmbeddingExecutor::AddEmbedding(
+ const TensorView<int>& sparse_features, float* dest, int dest_size) const {
+ if (dest_size != output_embedding_size_) {
+ TC3_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: "
+ << dest_size << " " << output_embedding_size_;
+ return false;
+ }
+ const int num_sparse_features = sparse_features.size();
+ for (int i = 0; i < num_sparse_features; ++i) {
+ const int bucket_id = sparse_features.data()[i];
+ if (bucket_id >= num_buckets_) {
+ return false;
+ }
+
+ if (!DequantizeAdd(scales_->data.f, embeddings_->data.uint8,
+ bytes_per_embedding_, num_sparse_features,
+ quantization_bits_, bucket_id, dest, dest_size)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace libtextclassifier3