Diffstat (limited to 'tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py')
-rw-r--r--  tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py  |  125
1 file changed, 125 insertions, 0 deletions
diff --git a/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py
new file mode 100644
index 00000000..21efed56
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py
@@ -0,0 +1,125 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Lint as: python3
+"""Python class that implements Sentencepiece tokenizer.
+
+It follows TF.text designers design.
+
+"""
+import tensorflow.compat.v2 as tf  # pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.framework import load_library  # pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.ops.ragged import ragged_tensor  # pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.platform import resource_loader  # pylint: disable=g-direct-tensorflow-import
+from tensorflow_lite_support.custom_ops.kernel.sentencepiece.py import pywrap_model_converter as model_converter
+
+gen_sentencepiece_detokenizer_op = load_library.load_op_library(
+    resource_loader.get_path_to_datafile(
+        '../kernel/sentencepiece/sentencepiece_detokenizer_op.so'))
+gen_sentencepiece_tokenizer_op = load_library.load_op_library(
+    resource_loader.get_path_to_datafile(
+        '../kernel/sentencepiece/sentencepiece_tokenizer_op.so'))
+
+
+class SentencepieceTokenizer:
+ """Sentencepiece tokenizer with tf.text interface."""
+
+ def __init__(self, model, reverse=False, add_bos=False, add_eos=False):
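+    """Creates the tokenizer from a serialized SentencePiece model.
+
+    Args:
+      model: The serialized SentencePiece model (as `bytes`).
+      reverse: Whether to reverse the resulting token sequence.
+      add_bos: Whether to prepend a beginning-of-sentence token.
+      add_eos: Whether to append an end-of-sentence token.
+    """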
+ converted_model = model_converter.convert_sentencepiece_model(model)
+ converted_model_detokenizer = model_converter.convert_sentencepiece_model_for_decoder(
+ model)
+    # Use a uint8 tensor as the buffer for the model to avoid any possible
+    # modification, e.g. truncation at '\0'.
+ self._converted_model = tf.constant(list(converted_model), dtype=tf.uint8)
+ self._converted_model_detokenizer = tf.constant(
+ list(converted_model_detokenizer), dtype=tf.uint8)
+ self._vocab_size = model_converter.get_vocabulary_size(converted_model)
+ self._reverse = reverse
+ self._add_bos = add_bos
+ self._add_eos = add_eos
+
+ def tokenize(self, inputs):
+ """The main tokenization function."""
+ input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(inputs)
+ if input_tensor.shape.ndims is None:
+ raise ValueError("Rank of input_tensor must be statically known.")
+ if ragged_tensor.is_ragged(input_tensor):
+      # Ensure that the row_splits dtype of the input is int32.
+ input_tensor = input_tensor.with_row_splits_dtype(tf.int32)
+ # Recursively process the values of the ragged tensor.
+ tokens = self.tokenize(input_tensor.flat_values)
+ return input_tensor.with_flat_values(tokens)
+ else:
+ if input_tensor.shape.ndims > 1:
+ # Convert the input tensor to ragged and process it.
+ return self.tokenize(
+ tf.RaggedTensor.from_tensor(
+ input_tensor, row_splits_dtype=tf.int32))
+ elif input_tensor.shape.ndims == 0:
+ tokens = self.tokenize(tf.stack([input_tensor]))
+ return tokens.values
+ else:
+        # Our rank 1 tensor is the correct shape, so we can process it as
+        # normal.
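+        # NOTE: the two literal 0 arguments passed to the op below are
+        # assumed to be the SentencePiece sampling parameters (nbest_size
+        # and alpha); this wrapper does not expose them.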
+ (output_values, row_splits) = (
+ gen_sentencepiece_tokenizer_op.tf_sentencepiece_tokenize_op(
+ self._converted_model, input_tensor, 0, 0, self._add_bos,
+ self._add_eos, self._reverse))
+ tokens = tf.RaggedTensor.from_nested_row_splits(
+ flat_values=output_values,
+ nested_row_splits=[row_splits],
+ validate=False)
+ return tokens
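+
+  # A minimal usage sketch, assuming `serialized_model` holds the bytes of a
+  # trained SentencePiece model (the exact ids depend on that model):
+  #
+  #   tokenizer = SentencepieceTokenizer(serialized_model)
+  #   tokens = tokenizer.tokenize(['Hello world.', 'How are you?'])
+  #   # tokens is a 2-D int32 RaggedTensor with one row of ids per string.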
+
+ def detokenize(self, input): # pylint: disable=redefined-builtin
+ """Detokenizes tokens into preprocessed text.
+
+ Args:
+ input: A `RaggedTensor` or `Tensor` with int32 encoded text with rank >=
+ 1.
+
+ Returns:
+ A N-1 dimensional string Tensor or RaggedTensor of the detokenized text.
+ """
+ input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
+ if input_tensor.shape.ndims is None:
+ raise ValueError("Rank of input_tensor must be statically known.")
+ if input_tensor.shape.ndims == 0:
+ raise ValueError("Rank of input_tensor must be at least 1.")
+ if ragged_tensor.is_ragged(input_tensor):
+ if input_tensor.flat_values.shape.ndims > 1:
+ # If the flat_values of our ragged tensor is multi-dimensional, we can
+ # process it separately and our output will have the same nested
+ # splits as our input.
+ tokens = self.detokenize(input_tensor.flat_values)
+ return input_tensor.with_flat_values(tokens)
+ elif input_tensor.ragged_rank > 1:
+ # Recursively process the values of the ragged tensor.
+ tokens = self.detokenize(input_tensor.values)
+ return input_tensor.with_values(tokens)
+ else:
+ return gen_sentencepiece_detokenizer_op.tf_sentencepiece_detokenize_op(
+ self._converted_model_detokenizer, input_tensor.flat_values,
+ input_tensor.row_splits)
+ else:
+ if input_tensor.shape.ndims > 1:
+ # Convert the input tensor to ragged and process it.
+ return self.detokenize(
+ tf.RaggedTensor.from_tensor(
+ input_tensor, row_splits_dtype=tf.int32))
+ else:
+ tokens = self.detokenize(tf.stack([input_tensor]))
+ return tf.reshape(tokens, [])
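+
+  # A round-trip sketch with the tokenizer from the example above
+  # (hypothetical output; the exact text depends on the model):
+  #
+  #   ids = tokenizer.tokenize(['Hello world.'])
+  #   text = tokenizer.detokenize(ids)  # -> ['Hello world.']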
+
+ def vocab_size(self):
+ """Returns size of the vocabulary in Sentencepiece model."""
+ return self._vocab_size
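+
+
+# A minimal end-to-end sketch, assuming `serialized_model` holds the bytes of
+# a trained SentencePiece model:
+#
+#   tokenizer = SentencepieceTokenizer(serialized_model, add_eos=True)
+#   print(tokenizer.vocab_size())  # Size of the model's vocabulary.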