diff options
Diffstat (limited to 'annotator/model-executor.cc')
-rw-r--r-- | annotator/model-executor.cc | 124 |
1 file changed, 124 insertions, 0 deletions
diff --git a/annotator/model-executor.cc b/annotator/model-executor.cc new file mode 100644 index 0000000..7c57e8f --- /dev/null +++ b/annotator/model-executor.cc @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "annotator/model-executor.h" + +#include "annotator/quantization.h" +#include "utils/base/logging.h" + +namespace libtextclassifier3 { + +TensorView<float> ModelExecutor::ComputeLogits( + const TensorView<float>& features, tflite::Interpreter* interpreter) const { + if (!interpreter) { + return TensorView<float>::Invalid(); + } + interpreter->ResizeInputTensor(kInputIndexFeatures, features.shape()); + if (interpreter->AllocateTensors() != kTfLiteOk) { + TC3_VLOG(1) << "Allocation failed."; + return TensorView<float>::Invalid(); + } + + SetInput<float>(kInputIndexFeatures, features, interpreter); + + if (interpreter->Invoke() != kTfLiteOk) { + TC3_VLOG(1) << "Interpreter failed."; + return TensorView<float>::Invalid(); + } + + return OutputView<float>(kOutputIndexLogits, interpreter); +} + +std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::FromBuffer( + const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, + int quantization_bits) { + std::unique_ptr<TfLiteModelExecutor> executor = + TfLiteModelExecutor::FromBuffer(model_spec_buffer); + if (!executor) { + TC3_LOG(ERROR) << "Could not load TFLite model for embeddings."; + 
return nullptr; + } + + std::unique_ptr<tflite::Interpreter> interpreter = + executor->CreateInterpreter(); + if (!interpreter) { + TC3_LOG(ERROR) << "Could not build TFLite interpreter for embeddings."; + return nullptr; + } + + if (interpreter->tensors_size() != 2) { + return nullptr; + } + const TfLiteTensor* embeddings = interpreter->tensor(0); + if (embeddings->dims->size != 2) { + return nullptr; + } + int num_buckets = embeddings->dims->data[0]; + const TfLiteTensor* scales = interpreter->tensor(1); + if (scales->dims->size != 2 || scales->dims->data[0] != num_buckets || + scales->dims->data[1] != 1) { + return nullptr; + } + int bytes_per_embedding = embeddings->dims->data[1]; + if (!CheckQuantizationParams(bytes_per_embedding, quantization_bits, + embedding_size)) { + TC3_LOG(ERROR) << "Mismatch in quantization parameters."; + return nullptr; + } + + return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor( + std::move(executor), quantization_bits, num_buckets, bytes_per_embedding, + embedding_size, scales, embeddings, std::move(interpreter))); +} + +TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor( + std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits, + int num_buckets, int bytes_per_embedding, int output_embedding_size, + const TfLiteTensor* scales, const TfLiteTensor* embeddings, + std::unique_ptr<tflite::Interpreter> interpreter) + : executor_(std::move(executor)), + quantization_bits_(quantization_bits), + num_buckets_(num_buckets), + bytes_per_embedding_(bytes_per_embedding), + output_embedding_size_(output_embedding_size), + scales_(scales), + embeddings_(embeddings), + interpreter_(std::move(interpreter)) {} + +bool TFLiteEmbeddingExecutor::AddEmbedding( + const TensorView<int>& sparse_features, float* dest, int dest_size) const { + if (dest_size != output_embedding_size_) { + TC3_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: " + << dest_size << " " << output_embedding_size_; + return false; + 
} + const int num_sparse_features = sparse_features.size(); + for (int i = 0; i < num_sparse_features; ++i) { + const int bucket_id = sparse_features.data()[i]; + if (bucket_id >= num_buckets_) { + return false; + } + + if (!DequantizeAdd(scales_->data.f, embeddings_->data.uint8, + bytes_per_embedding_, num_sparse_features, + quantization_bits_, bucket_id, dest, dest_size)) { + return false; + } + } + return true; +} + +} // namespace libtextclassifier3 |