diff options
Diffstat (limited to 'fcp/tensorflow/make_external_dataset_test_graph.py')
-rw-r--r-- | fcp/tensorflow/make_external_dataset_test_graph.py | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/fcp/tensorflow/make_external_dataset_test_graph.py b/fcp/tensorflow/make_external_dataset_test_graph.py new file mode 100644 index 0000000..0bb405b --- /dev/null +++ b/fcp/tensorflow/make_external_dataset_test_graph.py @@ -0,0 +1,58 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Prints a GraphDef to stdout (for testing ExternalDataset).""" + +import argparse +import numpy as np +import tensorflow.compat.v1 as tf + +from fcp.tensorflow import external_dataset + + +def _ParseSingleExample(p): + # parse_example doesn't like scalars, so we reshape with [-1]. + features = tf.parse_example( + tf.reshape(p, [-1]), {"val": tf.FixedLenFeature([], dtype=tf.int64)}) + return features["val"] + + +def MakeGraph(): + """Makes a GraphDef.""" + + graph = tf.Graph() + + with graph.as_default(): + serialized_examples = external_dataset.ExternalDataset( + token=tf.placeholder(name="token", dtype=tf.string), + selector=tf.placeholder(name="selector", dtype=tf.string)) + + examples = serialized_examples.map(_ParseSingleExample) + + total = examples.reduce(np.int64(0), lambda x, y: x + y) + total = tf.identity(total, name="total") + + return graph + + +def _ParseArgs(): + parser = argparse.ArgumentParser() + parser.add_argument("--output", required=True, type=argparse.FileType("w")) + return parser.parse_args() + +if __name__ == "__main__": + args = _ParseArgs() + with args.output: + graph_def = MakeGraph().as_graph_def() + args.output.write(str(graph_def)) |