aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
blob: 54b34e4e33196837e94e231bfcf6535e2c01b90b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
 * Sentencepiece tflite detokenizer implementation.
 */
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h"

namespace tflite {
namespace ops {
namespace custom {
namespace sentencepiece {
namespace detokenizer {

// Position, within node->outputs, of the tensor that receives the decoded
// strings.
constexpr int kOutputValuesInd = 0;
// Initialization callback installed by Register_SENTENCEPIECE_DETOKENIZER().
// All arguments are ignored and no per-node state is created: the model is
// read from an input tensor inside Eval instead, so this always returns null.
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
                 size_t /*length*/) {
  return nullptr;
}
// Cleanup callback paired with Initialize(). Initialize() allocates nothing,
// so there is nothing to release here.
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}

// Validates the node's tensors and marks the output tensor as dynamic.
//
// Returns kTfLiteError (reported through `context`) when the node does not
// provide every input/output slot that Eval dereferences, or when the id or
// split inputs are not int32, which is the only representation Eval reads.
// Returns kTfLiteOk otherwise.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Eval indexes node->inputs and node->outputs directly, so make sure every
  // slot it uses exists before anything is dereferenced.
  const int num_inputs = NumInputs(node);
  TF_LITE_ENSURE(context, tensorflow::ops::kSPModelIndex < num_inputs);
  TF_LITE_ENSURE(context, tensorflow::ops::kInputIndex < num_inputs);
  TF_LITE_ENSURE(context, tensorflow::ops::kInputSplits < num_inputs);
  TF_LITE_ENSURE(context, kOutputValuesInd < NumOutputs(node));

  // Both integer inputs are accessed through data.i32 in Eval; any other
  // element type would be silently misinterpreted there.
  const TfLiteTensor& input_encoded =
      context->tensors[node->inputs->data[tensorflow::ops::kInputIndex]];
  TF_LITE_ENSURE(context, input_encoded.type == kTfLiteInt32);
  const TfLiteTensor& input_splits =
      context->tensors[node->inputs->data[tensorflow::ops::kInputSplits]];
  TF_LITE_ENSURE(context, input_splits.type == kTfLiteInt32);

  // The output is filled from a DynamicBuffer in Eval, so its storage cannot
  // be planned ahead of time.
  TfLiteTensor& output_values =
      context->tensors[node->outputs->data[kOutputValuesInd]];
  SetTensorToDynamic(&output_values);

  return kTfLiteOk;
}

// Decodes a ragged batch of SentencePiece ids into one string per row.
//
// Inputs (slots named by the tensorflow::ops constants):
//   kSPModelIndex - buffer passed unchanged to DecodeString as the model.
//   kInputIndex   - int32 ids of all rows, stored back to back.
//   kInputSplits  - int32 row boundaries; row i holds
//                   splits[i + 1] - splits[i] ids.
// Rows are consumed consecutively starting at element 0 of the id tensor;
// only the differences between neighbouring splits are used.
//
// Output kOutputValuesInd receives the decoded strings, one per row.
//
// Returns kTfLiteError if a row size is negative or runs past the end of the
// id tensor, or if DecodeString does not report success; kTfLiteOk otherwise.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor& model_tensor =
      context->tensors[node->inputs->data[tensorflow::ops::kSPModelIndex]];
  const auto model_buffer_data = model_tensor.data.data;
  const TfLiteTensor& input_encoded =
      context->tensors[node->inputs->data[tensorflow::ops::kInputIndex]];
  const int32_t* input_encoded_data = input_encoded.data.i32;
  const int64_t num_codes = NumElements(input_encoded.dims);
  const TfLiteTensor& input_splits =
      context->tensors[node->inputs->data[tensorflow::ops::kInputSplits]];
  // N split boundaries describe N - 1 rows; an empty split tensor yields a
  // negative count, which simply skips the loop below.
  const int num_of_sentences = NumElements(input_splits.dims) - 1;
  const int32_t* input_splits_data = input_splits.data.i32;

  DynamicBuffer buf;

  // Reused across rows so its capacity is kept between iterations.
  std::vector<int> codes_for_split;
  int64_t input_offset = 0;
  for (int i = 0; i < num_of_sentences; ++i) {
    // Computed in 64 bits so that extreme split values cannot overflow.
    const int64_t split_size =
        static_cast<int64_t>(input_splits_data[i + 1]) - input_splits_data[i];
    // The splits come from tensor data; reject rows that would make the copy
    // below read outside the id tensor.
    TF_LITE_ENSURE_MSG(
        context, split_size >= 0 && split_size <= num_codes - input_offset,
        "Sentencepiece detokenizer: row splits do not match the input ids");

    const int32_t* const row_begin = input_encoded_data + input_offset;
    codes_for_split.assign(row_begin, row_begin + split_size);

    const auto res = DecodeString(codes_for_split, model_buffer_data);
    TF_LITE_ENSURE_MSG(context, res.type == DecoderResultType::SUCCESS,
                       "Sentencepiece decoding failed");
    buf.AddString(res.decoded.data(), res.decoded.length());
    input_offset += split_size;
  }

  TfLiteTensor& output_values =
      context->tensors[node->outputs->data[kOutputValuesInd]];
  // A null shape lets the buffer choose the output shape itself.
  buf.WriteToTensor(&output_values, nullptr);
  return kTfLiteOk;
}
}  // namespace detokenizer
}  // namespace sentencepiece

// Returns the registration record for the SentencePiece detokenizer custom
// op, wiring up its Initialize, Free, Prepare and Eval callbacks. The record
// is a function-local static, so every call returns the same pointer.
TfLiteRegistration* Register_SENTENCEPIECE_DETOKENIZER() {
  namespace detok = sentencepiece::detokenizer;
  static TfLiteRegistration registration = {
      detok::Initialize, detok::Free, detok::Prepare, detok::Eval};
  return &registration;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite