diff options
Diffstat (limited to 'fcp/tensorflow/dictionary_ops_test.py')
-rw-r--r-- | fcp/tensorflow/dictionary_ops_test.py | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/fcp/tensorflow/dictionary_ops_test.py b/fcp/tensorflow/dictionary_ops_test.py new file mode 100644 index 0000000..9fee48a --- /dev/null +++ b/fcp/tensorflow/dictionary_ops_test.py @@ -0,0 +1,110 @@ +from absl.testing import parameterized +import tensorflow as tf +from google.protobuf import text_format +from fcp.dictionary import dictionary_pb2 +from fcp.tensorflow import dictionary_ops + + +class DictionaryOpsTest(tf.test.TestCase, parameterized.TestCase): + + def test_direct_tf_use_literal_dictionary(self): + dictionary = dictionary_pb2.DictionaryDescription() + text_format.Merge( + 'special_ids: < unk: 0 > ' + 'vocabulary: < ' + ' index: < token: "a" token: "b" token: "c" token: "d" >' + '>', + dictionary) + + lookup = dictionary_ops.dictionary_lookup( + tf.constant(['a', 'b', 'a', 'a', 'd', 'X']), + dictionary_description_proto=dictionary.SerializeToString()) + with tf.compat.v1.Session() as sess: + tokenized = sess.run(lookup) + self.assertEqual([1, 2, 1, 1, 4, 0], tokenized.tolist()) + + @parameterized.named_parameters( + ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) + def test_build_dictionary_with_output_blocklist(self, vocabulary_type): + # Build a dictionary, explicitly blocklisting the first token and + # implicitly blocklisting the last token via output_size. + dictionary = dictionary_ops.Dictionary.from_tokens( + ['01', '02', '10', '11'], + unk_id=0, + output_blocklist_tokens=['01'], + output_size=4, + vocabulary_type=vocabulary_type) + + if vocabulary_type in ( + dictionary_ops.VocabularyType.TOKEN_INDEX, + ): + result = dictionary_ops.dictionary_lookup( + [['01', '02', '10', '11', '12']], + dictionary_description_proto=dictionary.dictionary_description_proto) + + with tf.compat.v1.Session() as sess: + tokenized = sess.run(result) + self.assertEqual([[1, 2, 3, 4, 0]], tokenized.tolist()) + self.assertEqual( + [1, 4], list(dictionary.dictionary_description.output_blocklist_ids.id)) + + @parameterized.named_parameters( + ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) + def test_build_dictionary(self, vocabulary_type): + dictionary = dictionary_ops.Dictionary.from_tokens( + ['A', 'a', 'B', 'c'], + unk_id=0, + vocabulary_type=vocabulary_type) + + result = dictionary_ops.dictionary_lookup( + [['A', 'a', 'B', 'b', 'C', 'c', 'D', 'd']], + dictionary_description_proto=dictionary.dictionary_description_proto) + expected = [[1, 2, 3, 0, 0, 4, 0, 0]] + with tf.compat.v1.Session() as sess: + tokenized = sess.run(result) + self.assertEqual(expected, tokenized.tolist()) + + @parameterized.named_parameters( + ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) + def test_dictionary_should_raise_with_duplicate_tokens(self, vocabulary_type): + with self.assertRaisesRegex(ValueError, 'Duplicate tokens'): + dictionary_ops.Dictionary.from_tokens(['01', '02', '11', '10', '11'], + vocabulary_type=vocabulary_type) + + @parameterized.named_parameters( + ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) + def test_lookup_in_python(self, vocabulary_type): + dictionary = dictionary_ops.Dictionary.from_tokens( + ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type) + self.assertLen(dictionary, 5) + self.assertListEqual([1, 2, 3, 4, 0], + dictionary.lookup(['01', '02', '10', '11', '12'])) + + @parameterized.named_parameters( + ('token_index', dictionary_ops.VocabularyType.TOKEN_INDEX)) + def test_reverse_lookup_in_python(self, vocabulary_type): + dictionary = dictionary_ops.Dictionary.from_tokens( + ['01', '02', '10', '11'], unk_id=0, vocabulary_type=vocabulary_type) + self.assertLen(dictionary, 5) + rlookup = [ + t.decode('utf-8') for t in dictionary.reverse_lookup([3, 2, 1, 4, 0]) + ] + self.assertListEqual(['10', '02', '01', '11', ''], rlookup) + + def test_literal_dictionary_in_python(self): + dictionary_description = dictionary_pb2.DictionaryDescription() + text_format.Merge( + 'special_ids: < unk: 0 > ' + 'vocabulary: < ' + ' index: < token: "a" token: "b" token: "c" token: "d" >' + '>', + dictionary_description) + dictionary = dictionary_ops.Dictionary.from_dictionary_description( + dictionary_description) + self.assertListEqual([b'a', b'b', b'c', b'd'], dictionary.tokens) + + +if __name__ == '__main__': + # Required since the test still relies on v1 Session.run behavior. + tf.compat.v1.disable_v2_behavior() + tf.test.main() |