 fcp/tensorflow/dictionary_ops.py | 372 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 372 insertions(+), 0 deletions(-)
diff --git a/fcp/tensorflow/dictionary_ops.py b/fcp/tensorflow/dictionary_ops.py
new file mode 100644
index 0000000..b168087
--- /dev/null
+++ b/fcp/tensorflow/dictionary_ops.py
@@ -0,0 +1,372 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Python and TensorFlow functions to work with dictionaries.
+
+Please see fcp/dictionary/dictionary.h for more on this type of
+dictionary.
+
+Python Classes:
+
+* `Dictionary`: A Python analogue to fcp/dictionary/dictionary.h
+ that includes additional helpers for dictionary construction.
+
+TensorFlow ops:
+
+* dictionary_size
+ Queries the size of a dictionary.
+
+* dictionary_lookup
+ Looks up ids for string tokens in the dictionary.
+
+* dictionary_reverse_lookup
+ Looks up string tokens from ids in the dictionary.
+
+Canonical use (note that the dictionary is known at graph construction time):
+ dictionary = Dictionary.from_tokens(
+ tokens=['some', 'token', 'list'], unk_id=0,
+ vocabulary_type=VocabularyType.TOKEN_INDEX)
+
+ with tf.Graph().as_default():
+    tokens = tf.compat.v1.placeholder(tf.string, ...)  # Tokens to look up.
+ ids = dictionary_lookup(
+ tokens, dictionary.dictionary_description_proto)
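+
+    # The reverse lookup and size ops follow the same pattern, for example:
+    token_strings = dictionary_reverse_lookup(
+        ids, dictionary.dictionary_description_proto)
+    vocab_size = dictionary_size(dictionary.dictionary_description_proto)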
+"""
+
+import collections
+import enum
+
+import tensorflow as tf
+
+from fcp.dictionary.dictionary_pb2 import DictionaryDescription # pylint: disable=g-importing-member
+from fcp.tensorflow.gen_dictionary_ops import dictionary_lookup
+from fcp.tensorflow.gen_dictionary_ops import dictionary_reverse_lookup
+from fcp.tensorflow.gen_dictionary_ops import dictionary_size
+
+# Load the custom kernels; the Python wrappers for the ops are imported above
+# from fcp.tensorflow.gen_dictionary_ops.
+_dictionary_ops = tf.load_op_library(
+    tf.compat.v1.resource_loader.get_path_to_datafile('./_dictionary_ops.so'))
+
+
+def ignore_ids_mask(token_ids, ignore_ids, name=None):
+ """Creates a bool mask with True everywhere token_ids is not in ignore_ids."""
+  with tf.compat.v1.name_scope(name, 'ignore_ids_mask',
+                               [token_ids, ignore_ids]):
+    # Broadcast-compare each token id against every ignore id, then reduce
+    # over the trailing ignore_ids axis: True iff no ignore id matches.
+    all_check = tf.not_equal(tf.expand_dims(token_ids, -1), ignore_ids)
+    check = tf.reduce_all(all_check, axis=-1)
+    check.set_shape(token_ids.get_shape())
+    return check
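+
+# Illustrative usage of ignore_ids_mask (a sketch, not part of the original
+# module): with token_ids [[4, 0, 7]] and ignore_ids [0], the mask is
+# [[True, False, True]]:
+#
+#   mask = ignore_ids_mask(tf.constant([[4, 0, 7]]), ignore_ids=[0])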
+
+
+def mask_and_replace_padding(token_ids,
+ lengths,
+ eos_id=None,
+ special_tokens=(),
+ name=None):
+ """Creates a mask of valid tokens and sets padded values in id space.
+
+  This creates a mask of the same shape as `token_ids`, with a boolean
+  indicating whether each id is a valid token (i.e., not padding or a special
+  token). If `eos_id` is provided, ids past each row's length are replaced
+  with it. Since the dictionary never maps tokens to the eos or bos ids, the
+  padded positions would otherwise hold the unknown token id, which is wrong
+  if you need to predict the eos.
+
+ Args:
+ token_ids: A matrix `Tensor` of integer ids.
+ lengths: A vector `Tensor` of lengths for each row in token_ids.
+    eos_id: The end-of-sequence id. If provided, every token id after a row's
+      length is replaced with `eos_id`.
+    special_tokens: An iterable of special token ids that should not be
+      considered valid (e.g., the bos, eos, or unk ids).
+ name: Name scope for these ops.
+
+ Returns:
+ token_ids: `token_ids` with all tokens after a row's length replaced with
+ eos if provided.
+ mask: A bool `Tensor` the same shape as `token_ids` indicating which tokens
+ are valid.
+ """
+  with tf.compat.v1.name_scope(name, 'mask_and_replace_padding',
+                               [token_ids, lengths, eos_id, special_tokens]):
+    ranges = tf.range(0, tf.shape(token_ids)[1])
+
+    # Broadcast the column indices against each row's length: True for
+    # positions before the length, False for padding.
+    selected = tf.less(ranges, tf.expand_dims(lengths, -1))
+
+ if eos_id is not None:
+ token_ids = tf.where(
+ selected, token_ids,
+ tf.fill(
+ tf.shape(token_ids), tf.constant(eos_id, dtype=token_ids.dtype)))
+ if special_tokens:
+ mask = tf.logical_and(
+ ignore_ids_mask(token_ids, special_tokens), selected)
+ else:
+ mask = selected
+ return token_ids, mask
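+
+# Illustrative usage of mask_and_replace_padding (a sketch, not part of the
+# original module): with token_ids [[5, 9, 1]], lengths [2] and eos_id 2, the
+# returned ids are [[5, 9, 2]] and the mask is [[True, True, False]]:
+#
+#   ids, mask = mask_and_replace_padding(
+#       tf.constant([[5, 9, 1]]), lengths=tf.constant([2]), eos_id=2)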
+
+tf.no_gradient('DictionarySize')
+tf.no_gradient('DictionaryLookup')
+tf.no_gradient('DictionaryReverseLookup')
+
+
+class VocabularyType(enum.Enum):
+ """Valid vocabulary types for Dictionary construction.
+
+ TOKEN_INDEX: dictionary.dictionary_description contains an embedded map of
+ string names stored in order with ids assigned starting from the lowest
+ non-special id. Preserves order but is not compact.
+ """
+ TOKEN_INDEX = 3
+
+
+class Dictionary(object):
+ """Utility for working with fcp/dictionary/ via TensorFlow."""
+
+ def __init__(
+ self,
+ dictionary_description
+ ):
+ """Creates a dictionary from a dictionary_description.
+
+ Use static from_* constructor methods for building dictionaries from
+ common data types.
+
+ Args:
+ dictionary_description: A `dictionary_pb2.DictionaryDescription`
+ describing the dictionary.
+
+ Raises:
+ ValueError: An invalid dictionary description.
+ """
+ if not isinstance(dictionary_description, DictionaryDescription):
+ raise ValueError('Expected a DictionaryDescription')
+ if not dictionary_description.HasField('vocabulary'):
+ raise ValueError('dictionary_description has no vocabulary')
+
+ self._dictionary_description = dictionary_description
+
+    # Lazily constructed fields for lookup.
+    self._lookup_graph = None
+    self._lookup_placeholder = None
+    self._lookup_result = None
+    self._reverse_lookup_placeholder = None
+    self._reverse_lookup_result = None
+    self._size_result = None
+
+ @classmethod
+ def from_tokens(
+ cls,
+ tokens,
+ bos_id=None,
+ eos_id=None,
+ unk_id=None,
+ output_blocklist_tokens=None,
+ output_size=None,
+ vocabulary_type=VocabularyType.TOKEN_INDEX
+ ):
+ """Creates a dictionary from a provided list of tokens.
+
+ The id mappings to token ids depend on the vocabulary_type requested.
+
+    NB: The special ids, if any, must occupy the first ids, [0, num_specials).
+
+ Args:
+      tokens: A list of unique tokens for the dictionary; with TOKEN_INDEX,
+        ids are assigned in list order starting at the first non-special id.
+ bos_id: Token id for start of sequence.
+ eos_id: Token id for end of sequence.
+ unk_id: Token id for unknown words.
+ output_blocklist_tokens: A list of vocabulary tokens that should be
+ filtered from predictions (e.g., punctuation, bad words etc.).
+ output_size: If a positive integer, tokens with ids greater than this are
+ automatically added to the output blocklist.
+ vocabulary_type: `VocabularyType` to use, defaults to TOKEN_INDEX.
+
+ Returns:
+ A `Dictionary` instance.
+
+ Raises:
+ ValueError: If the special tokens don't have the lowest ids.
+ ValueError: If there are duplicates in tokens.
+ """
+ dictionary_description = DictionaryDescription()
+
+ # Special ids.
+ special_ids = []
+ if unk_id is not None:
+ dictionary_description.special_ids.unk = unk_id
+ special_ids.append(unk_id)
+ if bos_id is not None:
+ dictionary_description.special_ids.bos = bos_id
+ special_ids.append(bos_id)
+ if eos_id is not None:
+ dictionary_description.special_ids.eos = eos_id
+ special_ids.append(eos_id)
+ if sorted(special_ids) != list(range(len(special_ids))):
+      raise ValueError(
+          'Special ids must be the first items of the dictionary starting at '
+          '0, or None. eos: %s; bos: %s; unk: %s' % (eos_id, bos_id, unk_id))
+
+ # Vocabulary.
+ if len(tokens) != len(set(tokens)):
+ raise ValueError('Duplicate tokens provided')
+ for token in tokens:
+ if not isinstance(token, (str, bytes)):
+ raise ValueError('Bad type in tokens %s' % token)
+ if vocabulary_type == VocabularyType.TOKEN_INDEX:
+ for token in tokens:
+ dictionary_description.vocabulary.index.token.append(token)
+ else:
+ raise AssertionError('Unsupported vocabulary_type: %s' % vocabulary_type)
+
+ # Output blocklist.
+ output_blocklist_tokens = list(output_blocklist_tokens or [])
+ if output_size:
+ assert output_size >= len(special_ids), (
+ 'Cannot blocklist special tokens via output_size.')
+      assert isinstance(tokens, list)  # Slicing requires a stable order.
+ output_blocklist_tokens.extend(tokens[output_size - len(special_ids):])
+ for token in output_blocklist_tokens:
+ assert token in tokens, "Unexpected blocklist token: '%s'" % token
+ with tf.compat.v1.Session(graph=tf.Graph()) as sess:
+ output_blocklist_ids = sess.run(
+ dictionary_lookup(output_blocklist_tokens,
+ dictionary_description.SerializeToString()))
+ dictionary_description.output_blocklist_ids.id.extend(
+ sorted(output_blocklist_ids))
+ assert (len(set(dictionary_description.output_blocklist_ids.id)) == len(
+ output_blocklist_tokens)), 'blocklist contains dups or unks?'
+
+ # Return completed dictionary.
+ return cls(
+ dictionary_description=dictionary_description)
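+
+  # Illustrative construction (a sketch, not part of the original module):
+  #
+  #   d = Dictionary.from_tokens(['the', 'cat'], unk_id=0)
+  #   d.lookup(['the', 'dog'])  # -> [1, 0]; 'dog' maps to unk_id.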
+
+ @classmethod
+ def from_dictionary_description(cls,
+ dictionary_description):
+ """Returns a Dictionary from a DictionaryDescription."""
+ return cls(
+ dictionary_description=dictionary_description)
+
+ def _get_lookup_graph(self):
+ """Returns a graph to use for lookup, reverse lookup, and size queries."""
+ if self._lookup_graph is None:
+ self._lookup_graph = tf.Graph()
+ serialized_description_proto = (
+ self._dictionary_description.SerializeToString())
+ with self._lookup_graph.as_default():
+ self._lookup_placeholder = tf.compat.v1.placeholder(
+ tf.string, shape=None)
+ self._reverse_lookup_placeholder = tf.compat.v1.placeholder(
+ tf.int64, shape=None)
+
+ # Use Dictionary(Op) (without blob) variants.
+ self._lookup_result = dictionary_lookup(
+ self._lookup_placeholder,
+ dictionary_description_proto=serialized_description_proto)
+ self._reverse_lookup_result = dictionary_reverse_lookup(
+ self._reverse_lookup_placeholder,
+ dictionary_description_proto=serialized_description_proto)
+ self._size_result = dictionary_size(
+ dictionary_description_proto=serialized_description_proto)
+
+ return self._lookup_graph
+
+ def lookup(self, tokens):
+ """Maps a list of tokens to a list of ids.
+
+ Args:
+ tokens: A list of tokens to lookup.
+
+ Returns:
+ A list of token ids of the same size.
+
+ Raises:
+ ValueError: If tokens is not a list.
+ """
+ if not isinstance(tokens, list):
+ raise ValueError('lookup expected a list of tokens.')
+
+ with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
+ return sess.run(self._lookup_result, {
+ self._lookup_placeholder: tokens
+ }).tolist()
+
+ def reverse_lookup(self, ids):
+ """Maps a list of ids to tokens.
+
+ Args:
+ ids: A list of ids to map back to tokens.
+
+ Returns:
+ A list of tokens corresponding to those ids.
+
+ Raises:
+ ValueError: If ids is not a list.
+ """
+ if not isinstance(ids, list):
+ raise ValueError('reverse_lookup expected a list of ids.')
+ with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
+ return list(
+ sess.run(self._reverse_lookup_result,
+ {self._reverse_lookup_placeholder: ids}))
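+
+  # Illustrative round trip (a sketch, not part of the original module):
+  #
+  #   d = Dictionary.from_tokens(['the', 'cat'], unk_id=0)
+  #   d.reverse_lookup(d.lookup(['cat']))  # -> [b'cat']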
+
+ @property
+ def special_ids(self):
+ """Returns a list of special token ids."""
+ return [t for t in [self.unk_id, self.bos_id, self.eos_id] if t is not None]
+
+ @property
+ def eos_id(self):
+ eos_id = self._dictionary_description.special_ids.eos
+ return eos_id if eos_id >= 0 else None
+
+ @property
+ def bos_id(self):
+ bos_id = self._dictionary_description.special_ids.bos
+ return bos_id if bos_id >= 0 else None
+
+ @property
+ def unk_id(self):
+ unk_id = self._dictionary_description.special_ids.unk
+ return unk_id if unk_id >= 0 else None
+
+ @property
+ def size(self):
+ with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
+ return sess.run(self._size_result)
+
+ @property
+ def output_blocklist_ids(self):
+ return list(self._dictionary_description.output_blocklist_ids.id)
+
+ @property
+ def output_blocklist_tokens(self):
+ return self.reverse_lookup(self.output_blocklist_ids)
+
+ @property
+ def tokens(self):
+ return self.reverse_lookup(list(range(len(self.special_ids), self.size)))
+
+ @property
+ def dictionary_description_proto(self):
+ """Serialized proto containing self.dictionary_description."""
+ return self.dictionary_description.SerializeToString()
+
+ @property
+  def dictionary_description(self):
+    """Returns the `DictionaryDescription` proto describing this dictionary."""
+    return self._dictionary_description
+
+ def __len__(self):
+ return self.size