Diffstat (limited to 'tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc')
-rw-r--r--  tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc | 129
1 file changed, 129 insertions(+), 0 deletions(-)
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
new file mode 100644
index 00000000..8309a6a2
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
@@ -0,0 +1,129 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/**
+ * SentencePiece TFLite tokenizer implementation.
+ */
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
+#include "flatbuffers/flexbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+namespace tokenizer {
+
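+// Indices of this op's outputs: a flat vector of token ids and the
+// row-split offsets that delimit each input string's tokens.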
+constexpr int kOutputValuesInd = 0;
+constexpr int kOutputSplitsInd = 1;
+
+namespace {
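+// Builds a TfLiteIntArray from the given dimension sizes; ownership of the
+// array is taken over by TfLiteContext::ResizeTensor in the callers below.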
+TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
+  TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
+  int index = 0;
+  for (const int size : sizes) {
+    array_size->data[index++] = size;
+  }
+  return array_size;
+}
+} // namespace
+
+// Initializes text encoder object from serialized parameters.
+void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
+                 size_t /*length*/) {
+  return nullptr;
+}
+void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
+
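+// The output sizes depend on how many tokens the inputs produce, so both
+// outputs are marked dynamic here and resized during Eval.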
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  // TODO(mgubin): Add checks for input and output tensors.
+  TfLiteTensor& output_values =
+      context->tensors[node->outputs->data[kOutputValuesInd]];
+  SetTensorToDynamic(&output_values);
+
+  TfLiteTensor& output_splits =
+      context->tensors[node->outputs->data[kOutputSplitsInd]];
+  SetTensorToDynamic(&output_splits);
+  return kTfLiteOk;
+}
+
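+// Encodes every input string with the serialized SentencePiece model and
+// writes the token ids plus the row-split offsets that bound each string.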
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor& model_tensor =
+      context->tensors[node->inputs->data[tensorflow::ops::kSPModelIndex]];
+  const auto model_buffer_data = model_tensor.data.data;
+  const TfLiteTensor& input_text =
+      context->tensors[node->inputs->data[tensorflow::ops::kInputIndex]];
+
+  const TfLiteTensor& add_bos_tensor =
+      context->tensors[node->inputs->data[tensorflow::ops::kAddBOSInput]];
+  const bool add_bos = add_bos_tensor.data.b[0];
+  const TfLiteTensor& add_eos_tensor =
+      context->tensors[node->inputs->data[tensorflow::ops::kAddEOSInput]];
+  const bool add_eos = add_eos_tensor.data.b[0];
+  const TfLiteTensor& reverse_tensor =
+      context->tensors[node->inputs->data[tensorflow::ops::kReverseInput]];
+  const bool reverse = reverse_tensor.data.b[0];
+
+  std::vector<int32_t> encoded;
+  std::vector<int32_t> splits;
+  const int num_strings = tflite::GetStringCount(&input_text);
+  for (int i = 0; i < num_strings; ++i) {
+    const auto strref = tflite::GetString(&input_text, i);
+    const auto res = EncodeString(std::string(strref.str, strref.len),
+                                  model_buffer_data, add_bos, add_eos, reverse);
+    TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
+                       "Sentencepiece conversion failed");
+    std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
+    splits.emplace_back(encoded.size());
+  }
+
+  TfLiteTensor& output_values =
+      context->tensors[node->outputs->data[kOutputValuesInd]];
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(
+                        context, &output_values,
+                        CreateSizeArray({static_cast<int>(encoded.size())})));
+  int32_t* output_values_flat = output_values.data.i32;
+  std::copy(encoded.begin(), encoded.end(), output_values_flat);
+  TfLiteTensor& output_splits =
+      context->tensors[node->outputs->data[kOutputSplitsInd]];
+  TF_LITE_ENSURE_OK(
+      context, context->ResizeTensor(
+                   context, &output_splits,
+                   CreateSizeArray({static_cast<int>(splits.size() + 1)})));
+  int32_t* output_splits_flat = output_splits.data.i32;
+  *output_splits_flat = 0;
+  std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
+  return kTfLiteOk;
+}
+} // namespace tokenizer
+} // namespace sentencepiece
+
+TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
+  static TfLiteRegistration r = {
+      sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
+      sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
+  return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
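
A minimal usage sketch (not part of this change) of how the kernel might be hooked up to an interpreter; the op name string "TFSentencepieceTokenizeOp", the model path, and the forward declaration are assumptions for illustration, not confirmed by this diff:

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

namespace tflite {
namespace ops {
namespace custom {
// Normally provided by the corresponding header; declared here only to keep
// the sketch self-contained.
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
}  // namespace custom
}  // namespace ops
}  // namespace tflite

int main() {
  // Register the custom kernel under the name stored in the model flatbuffer
  // (assumed here to be "TFSentencepieceTokenizeOp").
  tflite::ops::builtin::BuiltinOpResolver resolver;
  resolver.AddCustom("TFSentencepieceTokenizeOp",
                     tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());

  // Build an interpreter for a model that uses the tokenizer op
  // ("tokenizer.tflite" is a hypothetical path).
  auto model = tflite::FlatBufferModel::BuildFromFile("tokenizer.tflite");
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  interpreter->AllocateTensors();
  return 0;
}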