aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/custom_ops/python/tflite_text_api.py
blob: 1466df297a609b43d3ee19303157d1afc4bff190 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Wrapped TF.Text friendly to Tensorflow Lite conversion."""

import tensorflow as tf
import tensorflow_text as tf_text


class WhitespaceTokenizer(tf_text.Tokenizer):
  """TFLite friendly wrapper around tensorflow_text.WhitespaceTokenizer.

  Input strings are split on ICU defined whitespace characters, and those
  whitespace characters are dropped from the output. See more details in
  https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WhitespaceTokenizer.md

  tokenize_with_offsets() is not currently supported.
  """

  def __init__(self):
    super(WhitespaceTokenizer, self).__init__()
    # The underlying TF.Text tokenizer that performs the actual splitting.
    self._tokenizer = tf_text.WhitespaceTokenizer()

  def tokenize(self, input_tensor):
    """Splits each input string into whitespace-delimited tokens.

    Args:
      input_tensor: A `Tensor` of UTF-8 strings with rank 0, 1 or 2.

    Returns:
      A `RaggedTensor` of tokenized text. Its shape is the shape of the input
      tensor plus one extra ragged dimension holding the tokens of each string.
    """

    # The traced function and its argument keep their original names (`func`,
    # `input_tensor`) so the generated graph function is named as before.
    def func(input_tensor):
      return self._tokenizer.tokenize(input_tensor)

    # Wrap the delegate in a tf.function annotated as implementing
    # "tftext:WhitespaceTokenizer". A new wrapper is built on every call.
    implements = 'name: "tftext:WhitespaceTokenizer"'
    annotated_func = tf.function(func, experimental_implements=implements)
    return annotated_func(input_tensor)


def _escape_text_proto_string(value):
  """Escapes `value` for use inside a double-quoted text-format proto string.

  Args:
    value: The value to embed. Only `str` values are escaped; anything else is
      returned unchanged so that it is formatted exactly as before.

  Returns:
    `value` with backslashes, double quotes, newlines, carriage returns and
    tabs replaced by their backslash escape sequences, or `value` itself when
    it is not a `str`.
  """
  if not isinstance(value, str):
    return value
  replacements = (
      ('\\', '\\\\'),  # Must be first so later escapes are not re-escaped.
      ('"', '\\"'),
      ('\n', '\\n'),
      ('\r', '\\r'),
      ('\t', '\\t'),
  )
  for raw, escaped in replacements:
    value = value.replace(raw, escaped)
  return value


def ngrams(data,
           width,
           axis=-1,
           reduction_type=None,
           string_separator=' ',
           name=None):
  """TFLite friendly API for tensorflow_text.ngrams.

  Creates a tensor of n-grams from `data`, a token tensor. See more details in
  https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/ngrams.md

  Args:
    data: The data to reduce.  Must be convertible into a tf.Tensor or a
      tf.RaggedTensor (in which case it will be deconstructed into its component
      tf.Tensors).
    width: The width of the ngram window. If there is not sufficient data to
      fill out the ngram window, the resulting ngram will be empty.
    axis: The axis to create ngrams along. Only axis '-1' is accepted, because
      only string join reductions are currently supported. Should be a
      constant.
    reduction_type: A member of the Reduction enum. Should be a constant.
      Currently supports:
      * `Reduction.STRING_JOIN`: Join strings in the window. Note that axis must
        be -1 here.
    string_separator: The separator string used for `Reduction.STRING_JOIN`.
      Must be a string constant, not a Tensor. Characters that are special
      inside a quoted text-format proto string (backslash, double quote,
      newline, carriage return, tab) are escaped when the separator is recorded
      in the function's `experimental_implements` annotation.
    name: The op name.

  Returns:
    A tensor of ngrams.  If `data` is a ragged tensor, this will be a ragged
    tensor.  Otherwise it will be a plain tensor.

  Raises:
    tf.errors.InvalidArgumentError: If `reduction_type` is not
      `Reduction.STRING_JOIN`, or if `axis` is not -1.
  """

  if reduction_type is not tf_text.Reduction.STRING_JOIN:
    # TODO(b/162082752): Provide support for Reduction.SUM and Reduction.MEAN
    raise tf.errors.InvalidArgumentError(
        None, None, 'only Reduction.STRING_JOIN is currently supported')

  # reduction_type is guaranteed to be STRING_JOIN past the guard above, so
  # only the axis needs checking here.
  if axis != -1:
    raise tf.errors.InvalidArgumentError(
        None, None, 'For Reduction.STRING_JOIN, axis must be -1')

  # The separator is embedded in a quoted text-format string literal below.
  # Without escaping, a separator containing '"', '\\' or a newline would yield
  # a malformed literal or one that decodes to a different value than the
  # separator actually passed to tf_text.ngrams.
  separator_literal = _escape_text_proto_string(string_separator)

  experimental_implements = [
      'name: "tftext:Ngrams"',
      'attr { key: "width" value { i: %d } }' % width,
      'attr { key: "axis" value { i: %d } }' % axis,
      'attr { key: "reduction_type" value { s: "STRING_JOIN" } }',
      'attr { key: "string_separator" value { s: "%s" } }' % separator_literal,
  ]
  experimental_implements = ' '.join(experimental_implements)

  if isinstance(data, tf.RaggedTensor):

    # Since `data` can not be converted directly into a Tensor, we define
    # ragged_func() which takes a deconstructed tf.RaggedTensor
    # (one flat_values tensor and N row_splits tensors), pass it the
    # deconstructed version of `data`, and then immediately reconstruct it
    # within ragged_func().
    @tf.function(experimental_implements=experimental_implements)
    def ragged_func(values, *args):
      ragged_tensor = tf.RaggedTensor.from_nested_row_splits(
          flat_values=values, nested_row_splits=args)
      return tf_text.ngrams(ragged_tensor, width, axis, reduction_type,
                            string_separator, name)

    return ragged_func(data.flat_values, *data.nested_row_splits)

  # Plain (non-ragged) input can be passed to the annotated function directly.
  @tf.function(experimental_implements=experimental_implements)
  def func(data):
    return tf_text.ngrams(data, width, axis, reduction_type, string_separator,
                          name)

  return func(data)