author Tony Mak <tonymak@google.com> 2021-02-17 12:01:24 +0000
committer Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> 2021-02-17 12:01:24 +0000
commit 699e7bd5f0f6f12005fc96592c5b82e692dcfc56 (patch)
tree dbbe4fa5f2c1115321652b1376792217f94739cb
parent ed73f2f75b75da3675b0a984dbfd195702855ed1 (diff)
parent 815025dfce13a988896874b2f182de7c0e6cbad1 (diff)
Import platform/external/tflite-support am: ff3e0f1735 am: 815025dfce
Original change: https://android-review.googlesource.com/c/platform/external/tflite-support/+/1590212

MUST ONLY BE SUBMITTED BY AUTOMERGER

Change-Id: I8b292b6b90a02c93617557e2822e8251c36affec
-rw-r--r-- .bazelrc 170
-rw-r--r-- BUILD 5
-rw-r--r-- LICENSE 203
-rw-r--r-- METADATA 19
-rw-r--r-- MODULE_LICENSE_APACHE2 0
-rw-r--r-- README.md 62
-rw-r--r-- WORKSPACE 384
-rw-r--r-- tensorflow_lite_support/BUILD 38
-rw-r--r-- tensorflow_lite_support/cc/BUILD 25
-rw-r--r-- tensorflow_lite_support/cc/common.cc 35
-rw-r--r-- tensorflow_lite_support/cc/common.h 167
-rw-r--r-- tensorflow_lite_support/cc/port/BUILD 73
-rw-r--r-- tensorflow_lite_support/cc/port/benchmark.h 21
-rw-r--r-- tensorflow_lite_support/cc/port/build_defs.bzl 30
-rw-r--r-- tensorflow_lite_support/cc/port/default/BUILD 50
-rw-r--r-- tensorflow_lite_support/cc/port/default/status_macros.h 215
-rw-r--r-- tensorflow_lite_support/cc/port/default/statusor.cc 65
-rw-r--r-- tensorflow_lite_support/cc/port/default/statusor.h 574
-rw-r--r-- tensorflow_lite_support/cc/port/default/statusor_internals.h 409
-rw-r--r-- tensorflow_lite_support/cc/port/default/tflite_wrapper.cc 60
-rw-r--r-- tensorflow_lite_support/cc/port/default/tflite_wrapper.h 82
-rw-r--r-- tensorflow_lite_support/cc/port/gmock.h 21
-rw-r--r-- tensorflow_lite_support/cc/port/gtest.h 21
-rw-r--r-- tensorflow_lite_support/cc/port/integral_types.h 46
-rw-r--r-- tensorflow_lite_support/cc/port/status_macros.h 21
-rw-r--r-- tensorflow_lite_support/cc/port/statusor.h 20
-rw-r--r-- tensorflow_lite_support/cc/port/tflite_wrapper.h 21
-rw-r--r-- tensorflow_lite_support/cc/task/README.md 384
-rw-r--r-- tensorflow_lite_support/cc/task/core/BUILD 156
-rw-r--r-- tensorflow_lite_support/cc/task/core/base_task_api.h 144
-rw-r--r-- tensorflow_lite_support/cc/task/core/category.h 44
-rw-r--r-- tensorflow_lite_support/cc/task/core/external_file_handler.cc 194
-rw-r--r-- tensorflow_lite_support/cc/task/core/external_file_handler.h 94
-rw-r--r-- tensorflow_lite_support/cc/task/core/proto/BUILD 27
-rw-r--r-- tensorflow_lite_support/cc/task/core/proto/external_file.proto 67
-rw-r--r-- tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h 20
-rw-r--r-- tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt 2
-rw-r--r-- tensorflow_lite_support/cc/task/core/task_api_factory.h 100
-rw-r--r-- tensorflow_lite_support/cc/task/core/task_utils.cc 66
-rw-r--r-- tensorflow_lite_support/cc/task/core/task_utils.h 182
-rw-r--r-- tensorflow_lite_support/cc/task/core/tflite_engine.cc 297
-rw-r--r-- tensorflow_lite_support/cc/task/core/tflite_engine.h 245
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/BUILD 118
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc 198
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h 105
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc 70
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h 60
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc 467
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h 181
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc 89
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h 72
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc 30
-rw-r--r-- tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h 43
-rw-r--r-- tensorflow_lite_support/cc/task/text/qa/BUILD 61
-rw-r--r-- tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc 79
-rw-r--r-- tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h 78
-rw-r--r-- tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc 393
-rw-r--r-- tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h 170
-rw-r--r-- tensorflow_lite_support/cc/task/text/qa/question_answerer.h 65
-rw-r--r-- tensorflow_lite_support/cc/task/vision/BUILD 108
-rw-r--r-- tensorflow_lite_support/cc/task/vision/core/BUILD 81
-rw-r--r-- tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h 270
-rw-r--r-- tensorflow_lite_support/cc/task/vision/core/classification_head.cc 114
-rw-r--r-- tensorflow_lite_support/cc/task/vision/core/classification_head.h 110
-rw-r--r-- tensorflow_lite_support/cc/task/vision/core/frame_buffer.cc 179
-rw-r--r-- tensorflow_lite_support/cc/task/vision/core/frame_buffer.h 296
-rw-r--r-- tensorflow_lite_support/cc/task/vision/core/label_map_item.cc 128
-rw-r--r-- tensorflow_lite_support/cc/task/vision/core/label_map_item.h 95
-rw-r--r-- tensorflow_lite_support/cc/task/vision/image_classifier.cc 572
-rw-r--r-- tensorflow_lite_support/cc/task/vision/image_classifier.h 182
-rw-r--r-- tensorflow_lite_support/cc/task/vision/image_segmenter.cc 427
-rw-r--r-- tensorflow_lite_support/cc/task/vision/image_segmenter.h 172
-rw-r--r-- tensorflow_lite_support/cc/task/vision/object_detector.cc 549
-rw-r--r-- tensorflow_lite_support/cc/task/vision/object_detector.h 186
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/BUILD 208
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto 30
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h 19
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/class.proto 36
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h 20
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/classifications.proto 35
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h 22
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/detections.proto 53
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h 23
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto 67
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h 22
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto 61
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h 22
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto 62
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h 22
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/segmentations.proto 109
-rw-r--r-- tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h 19
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/BUILD 109
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc 428
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h 143
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc 619
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h 292
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h 88
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc 254
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h 93
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc 1499
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h 76
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc 225
-rw-r--r-- tensorflow_lite_support/cc/task/vision/utils/score_calibration.h 146
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/BUILD 191
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml 19
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc 108
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h 149
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc 87
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc 125
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h 59
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc 64
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h 74
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/tokenizer.h 55
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc 86
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h 36
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc 136
-rw-r--r-- tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h 41
-rw-r--r-- tensorflow_lite_support/cc/utils/BUILD 32
-rw-r--r-- tensorflow_lite_support/cc/utils/common_utils.cc 96
-rw-r--r-- tensorflow_lite_support/cc/utils/common_utils.h 49
-rw-r--r-- tensorflow_lite_support/cc/utils/jni_utils.cc 100
-rw-r--r-- tensorflow_lite_support/cc/utils/jni_utils.h 91
-rw-r--r-- tensorflow_lite_support/codegen/BUILD 86
-rw-r--r-- tensorflow_lite_support/codegen/README.md 13
-rw-r--r-- tensorflow_lite_support/codegen/android_java_generator.cc 1017
-rw-r--r-- tensorflow_lite_support/codegen/android_java_generator.h 116
-rw-r--r-- tensorflow_lite_support/codegen/code_generator.cc 179
-rw-r--r-- tensorflow_lite_support/codegen/code_generator.h 80
-rw-r--r-- tensorflow_lite_support/codegen/code_generator_test.cc 126
-rw-r--r-- tensorflow_lite_support/codegen/metadata_helper.cc 100
-rw-r--r-- tensorflow_lite_support/codegen/metadata_helper.h 51
-rw-r--r-- tensorflow_lite_support/codegen/python/BUILD 37
-rw-r--r-- tensorflow_lite_support/codegen/python/codegen.py 104
-rw-r--r-- tensorflow_lite_support/codegen/python/codegen_lib.cc 49
-rw-r--r-- tensorflow_lite_support/codegen/utils.cc 194
-rw-r--r-- tensorflow_lite_support/codegen/utils.h 127
-rw-r--r-- tensorflow_lite_support/codegen/utils_test.cc 97
-rw-r--r-- tensorflow_lite_support/custom_ops/BUILD 43
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/BUILD 146
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ngrams.cc 208
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ngrams.h 31
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc 31
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h 34
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver_wrapper.cc 29
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc 293
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ngrams_test.py 266
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ragged/BUILD 81
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD 27
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ragged/py/pywrap_tflite_registerer.cc 35
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.cc 31
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h 25
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc 192
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc 155
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc 690
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc 283
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD 389
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/config.fbs 25
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config.fbs 43
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h 120
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc 81
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h 41
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_test.cc 78
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config.fbs 52
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc 197
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h 52
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl 86
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.cc 63
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h 50
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc 90
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc 239
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h 52
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc 167
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.cc 34
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h 25
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h 43
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h 31
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_op.cc 94
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc 100
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h 31
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc 119
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc 129
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/testdata/sentencepiece.model bin 0 -> 330106 bytes
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h 66
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_flex_delegate.tflite bin 0 -> 23912 bytes
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_1d_input.tflite bin 0 -> 688 bytes
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_2d_input.tflite bin 0 -> 776 bytes
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_tensor.tflite bin 0 -> 600 bytes
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc 224
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h 31
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc 32
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h 34
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver_wrapper.cc 29
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.cc 189
-rw-r--r-- tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py 168
-rw-r--r-- tensorflow_lite_support/custom_ops/python/BUILD 61
-rw-r--r-- tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py 57
-rw-r--r-- tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py 125
-rw-r--r-- tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py 251
-rw-r--r-- tensorflow_lite_support/custom_ops/python/tflite_text_api.py 126
-rw-r--r-- tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite bin 0 -> 255136 bytes
-rw-r--r-- tensorflow_lite_support/custom_ops/tf_configure.sh 60
-rw-r--r-- tensorflow_lite_support/custom_ops/tflite_inference_main.cc 105
-rw-r--r-- tensorflow_lite_support/examples/task/text/desktop/BUILD 68
-rw-r--r-- tensorflow_lite_support/examples/task/text/desktop/README.md 134
-rw-r--r-- tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc 77
-rw-r--r-- tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc 81
-rw-r--r-- tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc 112
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/BUILD 68
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/README.md 180
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/g3doc/detection-output.png bin 0 -> 524248 bytes
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg bin 0 -> 83380 bytes
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg bin 0 -> 41901 bytes
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png bin 0 -> 1038 bytes
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg bin 0 -> 55282 bytes
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc 173
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc 201
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc 251
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD 22
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc 94
-rw-r--r-- tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h 58
-rw-r--r-- tensorflow_lite_support/ios/BUILD 48
-rw-r--r-- tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template 44
-rw-r--r-- tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt 3
-rw-r--r-- tensorflow_lite_support/ios/ios.bzl 30
-rw-r--r-- tensorflow_lite_support/ios/task/text/apis/TFLTaskText.h 17
-rw-r--r-- tensorflow_lite_support/ios/task/text/apis/framework.modulemap 4
-rw-r--r-- tensorflow_lite_support/ios/task/text/nlclassifier/BUILD 125
-rw-r--r-- tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h 51
-rw-r--r-- tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m 60
-rw-r--r-- tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h 86
-rw-r--r-- tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m 79
-rw-r--r-- tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m 61
-rw-r--r-- tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift 45
-rw-r--r-- tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m 65
-rw-r--r-- tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.swift 58
-rw-r--r-- tensorflow_lite_support/ios/task/text/qa/BUILD 71
-rw-r--r-- tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h 74
-rw-r--r-- tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m 71
-rw-r--r-- tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m 72
-rw-r--r-- tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.swift 63
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/BUILD 106
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h 38
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm 57
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h 33
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm 45
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h 39
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h 38
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm 41
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift 50
-rw-r--r-- tensorflow_lite_support/ios/text/tokenizers/Tests/TFLSentencepieceTokenizerTest.swift 37
-rw-r--r-- tensorflow_lite_support/ios/utils/BUILD 15
-rw-r--r-- tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h 23
-rw-r--r-- tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm 23
-rw-r--r-- tensorflow_lite_support/java/AndroidManifest.xml 5
-rw-r--r-- tensorflow_lite_support/java/BUILD 78
-rw-r--r-- tensorflow_lite_support/java/README.md 38
-rw-r--r-- tensorflow_lite_support/java/debug_version_script.lds 5
-rw-r--r-- tensorflow_lite_support/java/default_version_script.lds 12
-rw-r--r-- tensorflow_lite_support/java/jni/BUILD 48
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java 184
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java 31
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java 23
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java 82
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java 184
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java 27
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java 68
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java 55
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java 40
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java 160
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java 41
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java 80
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java 244
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java 212
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java 55
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java 146
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java 43
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java 198
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java 90
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java 312
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java 89
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java 125
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java 103
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java 75
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java 95
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java 64
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java 224
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java 74
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java 69
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java 285
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java 430
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java 115
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java 121
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD 22
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java 91
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java 165
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java 117
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml 5
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD 37
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD 79
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java 142
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java 257
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD 33
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java 195
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java 58
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java 32
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml 5
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD 41
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml 5
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD 40
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java 46
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java 453
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml 5
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD 40
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java 42
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java 452
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml 5
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD 41
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java 88
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java 377
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java 145
-rw-r--r-- tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java 82
-rw-r--r-- tensorflow_lite_support/java/src/native/task/core/BUILD 16
-rw-r--r-- tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc 27
-rw-r--r-- tensorflow_lite_support/java/src/native/task/text/BUILD 34
-rw-r--r-- tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD 63
-rw-r--r-- tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD 31
-rw-r--r-- tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc 74
-rw-r--r-- tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc 135
-rw-r--r-- tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc 56
-rw-r--r-- tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h 33
-rw-r--r-- tensorflow_lite_support/java/src/native/task/text/qa/BUILD 30
-rw-r--r-- tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc 127
-rw-r--r-- tensorflow_lite_support/java/src/native/task/vision/BUILD 59
-rw-r--r-- tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD 35
-rw-r--r-- tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc 234
-rw-r--r-- tensorflow_lite_support/java/src/native/task/vision/detector/BUILD 36
-rw-r--r-- tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc 228
-rw-r--r-- tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc 85
-rw-r--r-- tensorflow_lite_support/java/src/native/task/vision/jni_utils.h 38
-rw-r--r-- tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD 34
-rw-r--r-- tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc 238
-rw-r--r-- tensorflow_lite_support/metadata/BUILD 51
-rw-r--r-- tensorflow_lite_support/metadata/README.md 15
-rw-r--r-- tensorflow_lite_support/metadata/build_defs.bzl 43
-rw-r--r-- tensorflow_lite_support/metadata/cc/BUILD 53
-rw-r--r-- tensorflow_lite_support/metadata/cc/metadata_extractor.cc 366
-rw-r--r-- tensorflow_lite_support/metadata/cc/metadata_extractor.h 157
-rw-r--r-- tensorflow_lite_support/metadata/cc/metadata_parser.h.template 28
-rw-r--r-- tensorflow_lite_support/metadata/cc/metadata_version.cc 302
-rw-r--r-- tensorflow_lite_support/metadata/cc/metadata_version.h 38
-rw-r--r-- tensorflow_lite_support/metadata/cc/python/BUILD 22
-rw-r--r-- tensorflow_lite_support/metadata/cc/python/metadata_version.cc 55
-rw-r--r-- tensorflow_lite_support/metadata/flatbuffers_lib/BUILD 22
-rw-r--r-- tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc 59
-rw-r--r-- tensorflow_lite_support/metadata/metadata_schema.fbs 686
-rw-r--r-- tensorflow_lite_support/opensource/opensource_only.files 36
-rw-r--r-- tensorflow_lite_support/tools/BUILD 20
-rw-r--r-- tensorflow_lite_support/tools/build_rules/expand_template.bzl 50
-rw-r--r-- tensorflow_lite_support/tools/ci_build/build_all.sh 42
-rw-r--r-- tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh 114
-rw-r--r-- tensorflow_lite_support/tools/ci_build/common.sh 96
-rw-r--r-- tensorflow_lite_support/tools/ci_build/common_win.bat 29
-rw-r--r-- tensorflow_lite_support/tools/ci_build/update_version.py 120
-rw-r--r-- tensorflow_lite_support/tools/pip_package/BUILD 55
-rw-r--r-- tensorflow_lite_support/tools/pip_package/MANIFEST.in 9
-rw-r--r-- tensorflow_lite_support/tools/pip_package/README 1
-rwxr-xr-x tensorflow_lite_support/tools/pip_package/build_pip_package.sh 232
-rw-r--r-- tensorflow_lite_support/tools/pip_package/setup.py 154
-rw-r--r-- tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py 33
-rw-r--r-- tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py 28
-rw-r--r-- tensorflow_lite_support/tools/zip_files.py 41
-rw-r--r-- third_party/BUILD 1
-rw-r--r-- third_party/android/BUILD 1
-rw-r--r-- third_party/android/android.bzl.tpl 9
-rw-r--r-- third_party/android/android_configure.BUILD.tpl 0
-rw-r--r-- third_party/android/android_configure.bzl 95
-rw-r--r-- third_party/com_google_absl.BUILD 5
-rw-r--r-- third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff 14
-rw-r--r-- third_party/com_google_protobuf_fixes.diff 140
-rw-r--r-- third_party/darts_clone.BUILD 15
-rw-r--r-- third_party/fft2d/BUILD 48
-rw-r--r-- third_party/fft2d/LICENSE 3
-rw-r--r-- third_party/fft2d/fft.h 36
-rw-r--r-- third_party/fft2d/fft2d.BUILD 45
-rw-r--r-- third_party/fft2d/fft2d.h 36
-rw-r--r-- third_party/flatbuffers/BUILD 1
-rw-r--r-- third_party/flatbuffers/BUILD.bazel 140
-rw-r--r-- third_party/flatbuffers/build_defs.bzl 617
-rw-r--r-- third_party/flatbuffers/workspace.bzl 19
-rw-r--r-- third_party/gflags/BUILD 1
-rw-r--r-- third_party/gflags/fix_android_pthread_link.patch 32
-rw-r--r-- third_party/gflags/workspace.bzl 16
-rw-r--r-- third_party/google_toolbox_for_mac.BUILD 22
-rw-r--r-- third_party/icu.BUILD 97
-rw-r--r-- third_party/libyuv.BUILD 25
-rw-r--r-- third_party/libzip.BUILD 189
-rw-r--r-- third_party/py/BUILD 0
-rw-r--r-- third_party/py/BUILD.tpl 31
-rw-r--r-- third_party/py/python_configure.bzl 71
-rw-r--r-- third_party/pybind11.BUILD 25
-rw-r--r-- third_party/python_runtime/BUILD 8
-rw-r--r-- third_party/repo.bzl 152
-rw-r--r-- third_party/six.BUILD 14
-rw-r--r-- third_party/stblib.BUILD 26
-rw-r--r-- third_party/tensorflow/BUILD 1
-rw-r--r-- third_party/tensorflow/BUILD.tpl 18
-rw-r--r-- third_party/tensorflow/tf_configure.bzl 224
-rw-r--r-- third_party/tensorflow_lite_ios_build.patch 40
-rw-r--r-- third_party/tensorflow_text_remove_tf_deps.patch 32
-rw-r--r-- third_party/toolchains/java/BUILD 18
-rw-r--r-- third_party/utf.BUILD 38
-rw-r--r-- third_party/zlib.BUILD 39
412 files changed, 43507 insertions, 0 deletions
diff --git a/.bazelrc b/.bazelrc
new file mode 100644
index 00000000..bdc2c07d
--- /dev/null
+++ b/.bazelrc
@@ -0,0 +1,170 @@
+# This file is based on tensorflow's (v2.2.0) .bazelrc found here:
+# https://github.com/tensorflow/tensorflow/blob/v2.2.0/.bazelrc
+
+# Sets the default Apple platform to macOS.
+
+build --apple_platform_type=macos
+
+# Enable using platform specific build settings
+build --enable_platform_specific_config
+
+# Flag to enable remote config. Required starting from TF 2.2.
+common --experimental_repo_remote_exec
+
+# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1
+build --java_toolchain=//third_party/toolchains/java:tf_java_toolchain
+build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
+
+# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
+build:android --copt=-w
+build:linux --copt=-w
+build:macos --copt=-w
+build:windows --copt=/w
+
+# Android workspace configurations. Should be replaced by an interactive configure in the future.
+build --action_env ANDROID_NDK_HOME
+build --action_env ANDROID_NDK_API_LEVEL
+build --action_env ANDROID_BUILD_TOOLS_VERSION
+build --action_env ANDROID_SDK_API_LEVEL
+build --action_env ANDROID_SDK_HOME
+
+# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the
+# target CPU to build transient dependencies correctly. See
+# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu
+
+build:android --crosstool_top=//external:android/crosstool
+build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
+build:android_arm --config=android
+build:android_arm --cpu=armeabi-v7a
+build:android_arm --fat_apk_cpu=armeabi-v7a
+build:android_arm64 --config=android
+build:android_arm64 --cpu=arm64-v8a
+build:android_arm64 --fat_apk_cpu=arm64-v8a
+build:android_x86 --config=android
+build:android_x86 --cpu=x86
+build:android_x86 --fat_apk_cpu=x86
+build:android_x86_64 --config=android
+build:android_x86_64 --cpu=x86_64
+build:android_x86_64 --fat_apk_cpu=x86_64
+
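+# Example usage of the Android configs above (illustrative only; the target
+# label is a placeholder, not a target defined by this change):
+#   bazel build --config=android_arm64 //some/package:some_target
+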
+# iOS configs for each architecture and the fat binary builds.
+build:ios --apple_platform_type=ios
+build:ios --apple_bitcode=embedded --copt=-fembed-bitcode
+build:ios --copt=-Wno-c++11-narrowing
+build:ios_armv7 --config=ios
+build:ios_armv7 --cpu=ios_armv7
+build:ios_arm64 --config=ios
+build:ios_arm64 --cpu=ios_arm64
+build:ios_x86_64 --config=ios
+build:ios_x86_64 --cpu=ios_x86_64
+build:ios_fat --config=ios
+build:ios_fat --ios_multi_cpus=armv7,arm64,x86_64
+
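+# Example usage (illustrative only; placeholder target label):
+#   bazel build --config=ios_fat //some/package:some_ios_target
+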
+# By default, build TF in C++ 14 mode.
+build:android --cxxopt=-std=c++14
+build:android --host_cxxopt=-std=c++14
+build:ios --cxxopt=-std=c++14
+build:ios --host_cxxopt=-std=c++14
+build:linux --cxxopt=-std=c++14
+build:linux --host_cxxopt=-std=c++14
+build:macos --cxxopt=-std=c++14
+build:macos --host_cxxopt=-std=c++14
+build:windows --cxxopt=/std:c++14
+build:windows --host_cxxopt=/std:c++14
+
+# Config to use a mostly-static build and disable modular op registration
+# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
+# By default, TensorFlow will build with a dependence on
+# //tensorflow:libtensorflow_framework.so.
+build:monolithic --define framework_shared_object=false
+
+# For projects which use TensorFlow as part of a Bazel build process, putting
+# nothing in a bazelrc will default to a monolithic build. The following line
+# opts in to modular op registration support by default.
+build --define framework_shared_object=true
+
+# ASAN build
+build:asan --strip=never
+build:asan --copt -fsanitize=address
+build:asan --copt -DADDRESS_SANITIZER
+build:asan --copt -O1
+build:asan --copt -g
+build:asan --copt -fno-omit-frame-pointer
+build:asan --linkopt -fsanitize=address
+
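+# Example usage (illustrative only; placeholder test label). The asan config
+# above also applies when running tests:
+#   bazel test --config=asan //some/package:some_test
+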
+# Flags for open source build, always set to be true.
+build --define open_source_build=true
+test --define open_source_build=true
+
+# dbg config, as a shorthand for '--config=opt -c dbg'
+build:dbg --config=opt -c dbg
+# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
+build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
+# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
+build:dbg --copt -DDEBUG_BUILD
+
+build --define=use_fast_cpp_protos=true
+build --define=allow_oversize_protos=true
+
+build --spawn_strategy=standalone
+build -c opt
+
+# Adding "--cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" creates parity with TF
+# compilation options. It also addresses memory use due to
+# copy-on-write semantics of std::strings of the older ABI.
+build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0
+
+# Make Bazel print out all options from rc files.
+build --announce_rc
+
+# Other build flags.
+build --define=grpc_no_ares=true
+
+# See https://github.com/bazelbuild/bazel/issues/7362 for information on what
+# --incompatible_remove_legacy_whole_archive flag does.
+# This flag is set to true in Bazel 1.0 and newer versions. We tried to
+# migrate Tensorflow to the default; however, test coverage wasn't enough to
+# catch the errors.
+# There is ongoing work on Bazel team's side to provide support for transitive
+# shared libraries. As part of migrating to transitive shared libraries, we
+# hope to provide a better mechanism for control over symbol exporting, and
+# then tackle this issue again.
+#
+# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library
+# archives in -whole_archive -no_whole_archive.
+build --noincompatible_remove_legacy_whole_archive
+
+# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
+# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
+# https://github.com/tensorflow/community/pull/179
+build --noincompatible_prohibit_aapt1
+
+# Build TF with C++ 17 features.
+build:c++17 --cxxopt=-std=c++1z
+build:c++17 --cxxopt=-stdlib=libc++
+build:c++1z --config=c++17
+
+# Enable using platform specific build settings, except when cross-compiling for
+# mobile platforms.
+build --enable_platform_specific_config
+build:android --noenable_platform_specific_config
+build:ios --noenable_platform_specific_config
+
+# Suppress all warning messages.
+build:short_logs --output_filter=DONT_MATCH_ANYTHING
+build:verbose_logs --output_filter=
+build --config=short_logs
+
+# Options to build TensorFlow 1.x or 2.x.
+build:v1 --define=tf_api_version=1
+build:v2 --define=tf_api_version=2
+build:v1 --action_env=TF2_BEHAVIOR=0
+build:v2 --action_env=TF2_BEHAVIOR=1
+build --config=v2
+test --config=v2
+
+# Options from ./configure
+try-import %workspace%/.tf_configure.bazelrc
+
+# Put user-specific options in .bazelrc.user
+try-import %workspace%/.bazelrc.user
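+
+# As an illustration (not prescribed by this file), a .bazelrc.user might
+# hold machine-local options such as:
+#   build --disk_cache=/tmp/bazel_disk_cache
+#   build --jobs=8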
diff --git a/BUILD b/BUILD
new file mode 100644
index 00000000..abfcbddf
--- /dev/null
+++ b/BUILD
@@ -0,0 +1,5 @@
+exports_files(
+ [
+ "LICENSE",
+ ],
+)
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 00000000..786bd073
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2020 The TensorFlow Authors. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/METADATA b/METADATA
new file mode 100644
index 00000000..fb2db712
--- /dev/null
+++ b/METADATA
@@ -0,0 +1,19 @@
+name: "tflite-support"
+description:
+ "TFLite Support is a toolkit that helps users develop ML applications and "
+ "deploy TFLite models onto mobile devices. It works cross-platform and is "
+ "supported on Java, C++ (WIP), and Swift (WIP)."
+
+third_party {
+ url {
+ type: HOMEPAGE
+ value: "https://github.com/tensorflow/tflite-support"
+ }
+ url {
+ type: GIT
+ value: "https://github.com/tensorflow/tflite-support"
+ }
+ version: "v0.1.0"
+ last_upgrade_date { year: 2021 month: 1 day: 14 }
+ license_type: NOTICE
+}
\ No newline at end of file
diff --git a/MODULE_LICENSE_APACHE2 b/MODULE_LICENSE_APACHE2
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/MODULE_LICENSE_APACHE2
diff --git a/README.md b/README.md
new file mode 100644
index 00000000..67d4a8fd
--- /dev/null
+++ b/README.md
@@ -0,0 +1,62 @@
+# TensorFlow Lite Support
+
+TFLite Support is a toolkit that helps users develop ML applications and
+deploy TFLite models onto mobile devices. It works cross-platform and is
+supported on Java, C++ (WIP), and Swift (WIP). The TFLite Support project
+consists of the following major components:
+
+* **TFLite Support Library**: a cross-platform library that helps to
+ deploy TFLite models onto mobile devices.
+* **TFLite Model Metadata** (metadata populator and metadata extractor
+ library): includes both human- and machine-readable information about what a
+ model does and how to use the model.
+* **TFLite Support Codegen Tool**: an executable that generates model wrappers
+ automatically based on the Support Library and the metadata.
+* **TFLite Support Task Library**: a flexible and ready-to-use library for
+ common machine learning model types, such as classification and detection;
+ clients can also build their own native/Android/iOS inference APIs on the
+ Task Library infrastructure.
+
+The TFLite Support library serves different tiers of deployment requirements,
+from easy onboarding to full customizability. TFLite Support targets three
+major use cases:
+
+* **Provide ready-to-use APIs for users to interact with the model**. \
+ This is achieved by the TFLite Support Codegen tool, where users can get the
+ model interface (which contains ready-to-use APIs) simply by passing the model to
+ the codegen tool. The automatic codegen strategy is designed based on the
+ TFLite metadata.
+
+* **Provide optimized model interface for popular ML tasks**. \
+ The model interfaces provided by the TFLite Support Task Library are
+ specifically optimized compared to the codegen version in terms of both
+ usability and performance. Users can also swap their own custom models with
+ the default models in each task.
+
+* **Provide the flexibility to customize model interface and build inference
+ pipelines**. \
+ The TFLite Support Util Library contains a variety of utility methods and
+ data structures for pre/post-processing and data conversion. It is also
+ designed to match the behavior of TensorFlow modules, such as TF.Image and
+ TF.Text, ensuring consistency from training to inference.
+
+See the
+[documentation on tensorflow.org](https://www.tensorflow.org/lite/inference_with_metadata/overview)
+for more instructions and examples.
+
+## Build Instructions
+
+We use Bazel to build the project. When building the Java (Android) utils,
+you need to set the following environment variables correctly (a rough
+example follows the list):
+
+* `ANDROID_NDK_HOME`
+* `ANDROID_SDK_HOME`
+* `ANDROID_NDK_API_LEVEL`
+* `ANDROID_SDK_API_LEVEL`
+* `ANDROID_BUILD_TOOLS_VERSION`
+
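+As a sketch (the paths, values, and target label below are placeholders, not
+documented parts of this repository), a build might look like:
+
+```sh
+# Example SDK/NDK locations and versions (placeholders):
+export ANDROID_SDK_HOME="$HOME/android/sdk"
+export ANDROID_NDK_HOME="$HOME/android/ndk"
+export ANDROID_SDK_API_LEVEL=29
+export ANDROID_NDK_API_LEVEL=21
+export ANDROID_BUILD_TOOLS_VERSION=29.0.2
+
+# Build for Android arm64 using the config defined in .bazelrc
+# (the target label is a placeholder):
+bazel build --config=android_arm64 //some/package:some_target
+```
+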
+## Contact us
+
+Let us know what you think about TFLite Support by creating a
+[new GitHub issue](https://github.com/tensorflow/tflite-support/issues/new), or
+email us at tflite-support-team@google.com.
diff --git a/WORKSPACE b/WORKSPACE
new file mode 100644
index 00000000..21948710
--- /dev/null
+++ b/WORKSPACE
@@ -0,0 +1,384 @@
+workspace(name = "org_tensorflow_lite_support")
+
+load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external")
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("@//third_party/py:python_configure.bzl", "python_configure")
+
+http_archive(
+ name = "io_bazel_rules_closure",
+ sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
+ strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
+ "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
+ ],
+)
+
+# Apple and Swift rules.
+# https://github.com/bazelbuild/rules_apple/releases
+http_archive(
+ name = "build_bazel_rules_apple",
+ sha256 = "ee9e6073aeb5a65c100cb9c44b0017c937706a4ae03176e14a7e78620a198079",
+ strip_prefix = "rules_apple-5131f3d46794bf227d296c82f30c2499c9de3c5b",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_apple/archive/5131f3d46794bf227d296c82f30c2499c9de3c5b.tar.gz",
+ "https://github.com/bazelbuild/rules_apple/archive/5131f3d46794bf227d296c82f30c2499c9de3c5b.tar.gz",
+ ],
+)
+
+# https://github.com/bazelbuild/rules_swift/releases
+http_archive(
+ name = "build_bazel_rules_swift",
+ sha256 = "d0833bc6dad817a367936a5f902a0c11318160b5e80a20ece35fb85a5675c886",
+ strip_prefix = "rules_swift-3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_swift/archive/3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8.tar.gz",
+ "https://github.com/bazelbuild/rules_swift/archive/3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8.tar.gz",
+ ],
+)
+
+# TensorFlow 2.4.0
+http_archive(
+ name = "org_tensorflow",
+ sha256 = "26c833b7e1873936379e810a39d14700281125257ddda8cd822c89111db6f6ae",
+ strip_prefix = "tensorflow-2.4.0",
+ urls = [
+ "https://github.com/tensorflow/tensorflow/archive/v2.4.0.tar.gz",
+ ],
+ patches = ["@//third_party:tensorflow_lite_ios_build.patch"],
+ patch_args = ["-p1"],
+)
+
+# Set up dependencies. This needs to happen before setting up TF so that our
+# modifications can take effect.
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+# Use our patched gflags which fixes a linking issue.
+load("//third_party/gflags:workspace.bzl", gflags = "repo")
+gflags()
+
+third_party_http_archive(
+ name = "pybind11",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.6.0.tar.gz",
+ "https://github.com/pybind/pybind11/archive/v2.6.0.tar.gz",
+ ],
+ sha256 = "90b705137b69ee3b5fc655eaca66d0dc9862ea1759226f7ccd3098425ae69571",
+ strip_prefix = "pybind11-2.6.0",
+ build_file = "//third_party:pybind11.BUILD",
+)
+
+http_archive(
+ name = "absl_py",
+ sha256 = "603febc9b95a8f2979a7bdb77d2f5e4d9b30d4e0d59579f88eba67d4e4cc5462",
+ strip_prefix = "abseil-py-pypi-v0.9.0",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-py/archive/pypi-v0.9.0.tar.gz",
+ "https://github.com/abseil/abseil-py/archive/pypi-v0.9.0.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "six_archive",
+ build_file = "//third_party:six.BUILD",
+ sha256 = "d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73",
+ strip_prefix = "six-1.12.0",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz",
+ "https://pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "com_google_sentencepiece",
+ strip_prefix = "sentencepiece-1.0.0",
+ sha256 = "c05901f30a1d0ed64cbcf40eba08e48894e1b0e985777217b7c9036cac631346",
+ urls = [
+ "https://github.com/google/sentencepiece/archive/1.0.0.zip",
+ ],
+)
+
+http_archive(
+ name = "org_tensorflow_text",
+ sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8",
+ strip_prefix = "text-2.2.0",
+ urls = [
+ "https://github.com/tensorflow/text/archive/v2.2.0.zip",
+ ],
+ patches = ["@//third_party:tensorflow_text_remove_tf_deps.patch"],
+ patch_args = ["-p1"],
+ repo_mapping = {"@com_google_re2": "@com_googlesource_code_re2"},
+)
+
+http_archive(
+ name = "com_googlesource_code_re2",
+ sha256 = "d070e2ffc5476c496a6a872a6f246bfddce8e7797d6ba605a7c8d72866743bf9",
+ strip_prefix = "re2-506cfa4bffd060c06ec338ce50ea3468daa6c814",
+ urls = [
+ "https://github.com/google/re2/archive/506cfa4bffd060c06ec338ce50ea3468daa6c814.tar.gz",
+ ],
+)
+
+# ABSL cpp library lts_2020_02_25
+# Needed for absl/status
+http_archive(
+ name = "com_google_absl",
+ build_file = "//third_party:com_google_absl.BUILD",
+ urls = [
+ "https://github.com/abseil/abseil-cpp/archive/20200225.tar.gz",
+ ],
+ # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved.
+ patches = [
+ "@//third_party:com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff"
+ ],
+ patch_args = [
+ "-p1",
+ ],
+ strip_prefix = "abseil-cpp-20200225",
+ sha256 = "728a813291bdec2aa46eab8356ace9f75ac2ed9dfe2df5ab603c4e6c09f1c353"
+)
+
+http_archive(
+ name = "com_google_glog",
+ sha256 = "1ee310e5d0a19b9d584a855000434bb724aa744745d5b8ab1855c85bff8a8e21",
+ strip_prefix = "glog-028d37889a1e80e8a07da1b8945ac706259e5fd8",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/glog/archive/028d37889a1e80e8a07da1b8945ac706259e5fd8.tar.gz",
+ "https://github.com/google/glog/archive/028d37889a1e80e8a07da1b8945ac706259e5fd8.tar.gz",
+ ],
+)
+
+
+http_archive(
+ name = "zlib",
+ build_file = "//third_party:zlib.BUILD",
+ sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
+ strip_prefix = "zlib-1.2.11",
+ urls = [
+ "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz",
+ "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15
+ ],
+)
+
+http_archive(
+ name = "org_libzip",
+ build_file = "//third_party:libzip.BUILD",
+ sha256 = "a5d22f0c87a2625450eaa5e10db18b8ee4ef17042102d04c62e311993a2ba363",
+ strip_prefix = "libzip-rel-1-5-1",
+ urls = [
+ # Bazel does not like the official download link at libzip.org,
+ # so use the GitHub release tag.
+ "https://mirror.bazel.build/github.com/nih-at/libzip/archive/rel-1-5-1.zip",
+ "https://github.com/nih-at/libzip/archive/rel-1-5-1.zip",
+ ],
+)
+
+http_archive(
+ name = "libyuv",
+ urls = ["https://chromium.googlesource.com/libyuv/libyuv/+archive/6d603ec3f57dafddc424ef895e5d903915e94ba6.tar.gz"],
+    # Pinning sha256 and strip_prefix here causes failures: the libyuv
+    # archive served at this URL appears to differ between downloads, so a
+    # fixed sha256 and strip_prefix cannot match.
+ # sha256 = "ce196c72858456baa8022fa4a0dc18b77d619265dbc0e3d58e25ad15ca402522",
+ # strip_prefix = "libyuv-6d603ec3f57dafddc424ef895e5d903915e94ba6",
+ build_file = "//third_party:libyuv.BUILD",
+)
+
+http_archive(
+ name = "stblib",
+ strip_prefix = "stb-b42009b3b9d4ca35bc703f5310eedc74f584be58",
+ sha256 = "13a99ad430e930907f5611325ec384168a958bf7610e63e60e2fd8e7b7379610",
+ urls = ["https://github.com/nothings/stb/archive/b42009b3b9d4ca35bc703f5310eedc74f584be58.tar.gz"],
+ build_file = "//third_party:stblib.BUILD",
+)
+
+http_archive(
+ name = "google_toolbox_for_mac",
+ url = "https://github.com/google/google-toolbox-for-mac/archive/v2.2.1.zip",
+ sha256 = "e3ac053813c989a88703556df4dc4466e424e30d32108433ed6beaec76ba4fdc",
+ strip_prefix = "google-toolbox-for-mac-2.2.1",
+ build_file = "@//third_party:google_toolbox_for_mac.BUILD",
+)
+
+http_archive(
+ name = "utf_archive",
+ build_file = "@//third_party:utf.BUILD",
+ sha256 = "262a902f622dcd28e05b8a4be10da0aa3899050d0be8f4a71780eed6b2ea65ca",
+ urls = [
+ "https://mirror.bazel.build/9fans.github.io/plan9port/unix/libutf.tgz",
+ "https://9fans.github.io/plan9port/unix/libutf.tgz",
+ ],
+)
+
+http_archive(
+ name = "icu",
+ strip_prefix = "icu-release-64-2",
+ sha256 = "dfc62618aa4bd3ca14a3df548cd65fe393155edd213e49c39f3a30ccd618fc27",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/unicode-org/icu/archive/release-64-2.zip",
+ "https://github.com/unicode-org/icu/archive/release-64-2.zip",
+ ],
+ build_file = "@//third_party:icu.BUILD",
+)
+
+http_archive(
+ name = "fft2d",
+ build_file = "@//third_party/fft2d:fft2d.BUILD",
+ sha256 = "5f4dabc2ae21e1f537425d58a49cdca1c49ea11db0d6271e2a4b27e9697548eb",
+ strip_prefix = "OouraFFT-1.0",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/petewarden/OouraFFT/archive/v1.0.tar.gz",
+ "https://github.com/petewarden/OouraFFT/archive/v1.0.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "darts_clone",
+ build_file = "@//third_party:darts_clone.BUILD",
+ sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
+ strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
+ urls = [
+ "https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
+ ],
+)
+
+http_archive(
+ name = "com_google_protobuf",
+ sha256 = "a79d19dcdf9139fa4b81206e318e33d245c4c9da1ffed21c87288ed4380426f9",
+ strip_prefix = "protobuf-3.11.4",
+ urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.11.4.tar.gz"],
+ patches = [
+ "@//third_party:com_google_protobuf_fixes.diff"
+ ],
+ patch_args = [
+ "-p1",
+ ],
+)
+
+# AutoValue 1.6+ shades Guava, Auto Common, and JavaPoet. That's OK
+# because none of these jars become runtime dependencies.
+java_import_external(
+ name = "com_google_auto_value",
+ jar_sha256 = "fd811b92bb59ae8a4cf7eb9dedd208300f4ea2b6275d726e4df52d8334aaae9d",
+ jar_urls = [
+ "https://mirror.bazel.build/repo1.maven.org/maven2/com/google/auto/value/auto-value/1.6/auto-value-1.6.jar",
+ "https://repo1.maven.org/maven2/com/google/auto/value/auto-value/1.6/auto-value-1.6.jar",
+ ],
+ licenses = ["notice"], # Apache 2.0
+ generated_rule_name = "processor",
+ exports = ["@com_google_auto_value_annotations"],
+ extra_build_file_content = "\n".join([
+ "java_plugin(",
+ " name = \"AutoAnnotationProcessor\",",
+ " output_licenses = [\"unencumbered\"],",
+ " processor_class = \"com.google.auto.value.processor.AutoAnnotationProcessor\",",
+ " tags = [\"annotation=com.google.auto.value.AutoAnnotation;genclass=${package}.AutoAnnotation_${outerclasses}${classname}_${methodname}\"],",
+ " deps = [\":processor\"],",
+ ")",
+ "",
+ "java_plugin(",
+ " name = \"AutoOneOfProcessor\",",
+ " output_licenses = [\"unencumbered\"],",
+ " processor_class = \"com.google.auto.value.processor.AutoOneOfProcessor\",",
+ " tags = [\"annotation=com.google.auto.value.AutoValue;genclass=${package}.AutoOneOf_${outerclasses}${classname}\"],",
+ " deps = [\":processor\"],",
+ ")",
+ "",
+ "java_plugin(",
+ " name = \"AutoValueProcessor\",",
+ " output_licenses = [\"unencumbered\"],",
+ " processor_class = \"com.google.auto.value.processor.AutoValueProcessor\",",
+ " tags = [\"annotation=com.google.auto.value.AutoValue;genclass=${package}.AutoValue_${outerclasses}${classname}\"],",
+ " deps = [\":processor\"],",
+ ")",
+ "",
+ "java_library(",
+ " name = \"com_google_auto_value\",",
+ " exported_plugins = [",
+ " \":AutoAnnotationProcessor\",",
+ " \":AutoOneOfProcessor\",",
+ " \":AutoValueProcessor\",",
+ " ],",
+ " exports = [\"@com_google_auto_value_annotations\"],",
+ ")",
+ ]),
+)
+
+# Auto value annotations
+java_import_external(
+ name = "com_google_auto_value_annotations",
+ jar_sha256 = "d095936c432f2afc671beaab67433e7cef50bba4a861b77b9c46561b801fae69",
+ jar_urls = [
+ "https://mirror.bazel.build/repo1.maven.org/maven2/com/google/auto/value/auto-value-annotations/1.6/auto-value-annotations-1.6.jar",
+ "https://repo1.maven.org/maven2/com/google/auto/value/auto-value-annotations/1.6/auto-value-annotations-1.6.jar",
+ ],
+ licenses = ["notice"], # Apache 2.0
+ neverlink = True,
+ default_visibility = ["@com_google_auto_value//:__pkg__"],
+)
+
+load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
+
+flatbuffers()
+
+# Set up TF.
+load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
+tf_workspace(tf_repo_name="@org_tensorflow")
+
+load("//third_party/tensorflow:tf_configure.bzl", "tf_configure")
+tf_configure(name = "local_config_tf")
+
+# TF submodule compilation doesn't take care of grpc deps. Do it manually here.
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+grpc_deps()
+
+load(
+ "@build_bazel_rules_apple//apple:repositories.bzl",
+ "apple_rules_dependencies",
+)
+apple_rules_dependencies()
+
+load(
+ "@build_bazel_apple_support//lib:repositories.bzl",
+ "apple_support_dependencies",
+)
+apple_support_dependencies()
+
+load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
+bazel_version_repository(name = "bazel_version")
+
+
+# Set up Android.
+load("//third_party/android:android_configure.bzl", "android_configure")
+android_configure(name="local_config_android")
+load("@local_config_android//:android.bzl", "android_workspace")
+android_workspace()
+
+python_configure(name = "local_config_python")
+
+
+# Maven dependencies.
+
+RULES_JVM_EXTERNAL_TAG = "3.2"
+
+http_archive(
+ name = "rules_jvm_external",
+ strip_prefix = "rules_jvm_external-%s" % RULES_JVM_EXTERNAL_TAG,
+ sha256 = "82262ff4223c5fda6fb7ff8bd63db8131b51b413d26eb49e3131037e79e324af",
+ url = "https://github.com/bazelbuild/rules_jvm_external/archive/%s.zip" % RULES_JVM_EXTERNAL_TAG,
+)
+
+load("@rules_jvm_external//:defs.bzl", "maven_install")
+
+maven_install(
+ artifacts = [
+ "androidx.annotation:annotation:aar:1.1.0",
+ ],
+ repositories = [
+ "https://jcenter.bintray.com",
+ "https://maven.google.com",
+ "https://dl.google.com/dl/android/maven2",
+ "https://repo1.maven.org/maven2",
+ ],
+ fetch_sources = True,
+ version_conflict_policy = "pinned",
+)
diff --git a/tensorflow_lite_support/BUILD b/tensorflow_lite_support/BUILD
new file mode 100644
index 00000000..f123f1f2
--- /dev/null
+++ b/tensorflow_lite_support/BUILD
@@ -0,0 +1,38 @@
+# TFLite Support is a toolkit that helps users develop ML applications and
+# deploy TFLite models onto mobile devices.
+
+package(
+ default_visibility = ["//visibility:private"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(["LICENSE"])
+
+# LINT.IfChange
+package_group(
+ name = "users",
+ packages = [
+ # tensorflow_examples/... dep,
+ "//tensorflow_lite_support/...",
+ "//third_party/tensorflow_models/...",
+ ],
+)
+# Remove internal path from tensorflow_lite_support:users in the copybara file.
+# LINT.ThenChange(//tensorflow_lite_support/copy.bara.sky)
+
+# Config setting for determining if we are building for Android.
+config_setting(
+ name = "android",
+ values = {"crosstool_top": "//external:android/crosstool"},
+ visibility = ["//visibility:public"],
+)
+
+# Config setting for determining if we are building for macOS.
+config_setting(
+ name = "macos",
+ values = {
+ "apple_platform_type": "macos",
+ "cpu": "darwin",
+ },
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow_lite_support/cc/BUILD b/tensorflow_lite_support/cc/BUILD
new file mode 100644
index 00000000..b19bfdec
--- /dev/null
+++ b/tensorflow_lite_support/cc/BUILD
@@ -0,0 +1,25 @@
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "common",
+ srcs = [
+ "common.cc",
+ ],
+ hdrs = ["common.h"],
+ deps = [
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:cord",
+ ],
+)
+
+config_setting(
+ name = "tflite_use_c_api",
+ values = {
+ "copt": "-DTFLITE_USE_C_API",
+ },
+ visibility = ["//tensorflow_lite_support:__subpackages__"],
+)
diff --git a/tensorflow_lite_support/cc/common.cc b/tensorflow_lite_support/cc/common.cc
new file mode 100644
index 00000000..47dd3bcc
--- /dev/null
+++ b/tensorflow_lite_support/cc/common.cc
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/common.h"
+
+#include "absl/strings/cord.h"
+#include "absl/strings/str_cat.h"
+
+namespace tflite {
+namespace support {
+
+absl::Status CreateStatusWithPayload(absl::StatusCode canonical_code,
+ absl::string_view message,
+ TfLiteSupportStatus tfls_code) {
+ // NOTE: Ignores `message` if the canonical code is ok.
+ absl::Status status = absl::Status(canonical_code, message);
+ // NOTE: Does nothing if the canonical code is ok.
+ status.SetPayload(kTfLiteSupportPayload, absl::Cord(absl::StrCat(tfls_code)));
+ return status;
+}
+
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/common.h b/tensorflow_lite_support/cc/common.h
new file mode 100644
index 00000000..50c8dc40
--- /dev/null
+++ b/tensorflow_lite_support/cc/common.h
@@ -0,0 +1,167 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+
+namespace tflite {
+namespace support {
+
+// Name (aka type URL key) of the `absl::Status` payload which contains a
+// stringified `TfLiteSupportStatus` code (see below).
+constexpr absl::string_view kTfLiteSupportPayload =
+ "tflite::support::TfLiteSupportStatus";
+
+// Error codes for TensorFlow Lite Support (TFLS) C++ APIs.
+//
+// Such codes capture errors encountered in the TFLS layer. They complement all
+// the other types of errors that occur in the lower-level TF Lite codebase (see
+// `TfLiteStatus` codes).
+//
+// At runtime, such codes are meant to be attached (where applicable) to a
+// `absl::Status` in a key-value manner with `kTfLiteSupportPayload` as key and
+// stringified error code as value (aka payload). This logic is encapsulated in
+// the `CreateStatusWithPayload` helper below for convenience.
+//
+// For example, the status built by `CreateStatusWithPayload` for an invalid
+// metadata schema version includes:
+// 1. The canonical error code (INVALID_ARGUMENT)
+// 2. The fine-grained error message ("Invalid metadata ...")
+// 3. The specific TFLS code as a payload (kMetadataInvalidSchemaVersionError)
+enum class TfLiteSupportStatus {
+ // Generic error codes.
+
+ // Success.
+ kOk = 0,
+ // Unspecified error.
+ kError = 1,
+ // Invalid argument specified.
+ kInvalidArgumentError = 2,
+ // Invalid FlatBuffer file or buffer specified.
+ kInvalidFlatBufferError = 3,
+
+ // File I/O error codes.
+
+ // No such file.
+ kFileNotFoundError = 100,
+ // Permission issue.
+ kFilePermissionDeniedError,
+ // I/O error when reading file.
+ kFileReadError,
+ // I/O error when mmap-ing file.
+ kFileMmapError,
+
+ // TensorFlow Lite metadata error codes.
+
+ // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer.
+ kMetadataInvalidSchemaVersionError = 200,
+ // No such associated file within metadata, or file has not been packed.
+ kMetadataAssociatedFileNotFoundError,
+ // ZIP I/O error when unpacking an associated file.
+ kMetadataAssociatedFileZipError,
+ // Inconsistency error between the metadata and actual TF Lite model.
+ // E.g.: number of labels and output tensor values differ.
+ kMetadataInconsistencyError,
+ // Invalid process units specified.
+ // E.g.: multiple ProcessUnits with the same type for a given tensor.
+ kMetadataInvalidProcessUnitsError,
+ // Inconsistency error with the number of labels.
+ // E.g.: label files for different locales have a different number of labels.
+ kMetadataNumLabelsMismatchError,
+ // Score calibration parameters parsing error.
+ // E.g.: too many parameters provided in the corresponding associated file.
+ kMetadataMalformedScoreCalibrationError,
+ // Unexpected number of subgraphs for the current task.
+ // E.g.: image classification expects a single subgraph.
+ kMetadataInvalidNumSubgraphsError,
+ // A given tensor requires NormalizationOptions but none were found.
+ // E.g.: float input tensor requires normalization to preprocess input images.
+ kMetadataMissingNormalizationOptionsError,
+ // Invalid ContentProperties specified.
+ // E.g. expected ImageProperties, got BoundingBoxProperties.
+ kMetadataInvalidContentPropertiesError,
+ // Metadata is mandatory but was not found.
+ // E.g. current task requires TFLite Model Metadata but none was found.
+ kMetadataNotFoundError,
+ // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but
+ // none was found or it was empty.
+ // E.g. current task requires labels but none were found.
+ kMetadataMissingLabelsError,
+ // The ProcessingUnit for tokenizer is not correctly configured.
+ // E.g BertTokenizer doesn't have a valid vocab file associated.
+ kMetadataInvalidTokenizerError,
+
+ // Input tensor(s) error codes.
+
+ // Unexpected number of input tensors for the current task.
+ // E.g. current task expects a single input tensor.
+ kInvalidNumInputTensorsError = 300,
+ // Unexpected input tensor dimensions for the current task.
+ // E.g.: only 4D input tensors supported.
+ kInvalidInputTensorDimensionsError,
+ // Unexpected input tensor type for the current task.
+ // E.g.: current task expects a uint8 pixel image as input.
+ kInvalidInputTensorTypeError,
+ // Unexpected input tensor bytes size.
+ // E.g.: size in bytes does not correspond to the expected number of pixels.
+ kInvalidInputTensorSizeError,
+ // No correct input tensor found for the model.
+ // E.g.: input tensor name is not part of the text model's input tensors.
+ kInputTensorNotFoundError,
+
+ // Output tensor(s) error codes.
+
+ // Unexpected output tensor dimensions for the current task.
+ // E.g.: only a batch size of 1 is supported.
+ kInvalidOutputTensorDimensionsError = 400,
+ // Unexpected input tensor type for the current task.
+ // E.g.: multi-head model with different output tensor types.
+ kInvalidOutputTensorTypeError,
+ // No correct output tensor found for the model.
+ // E.g.: output tensor name is not part of the text model's output tensors.
+ kOutputTensorNotFoundError,
+ // Unexpected number of output tensors for the current task.
+ // E.g.: current task expects a single output tensor.
+ kInvalidNumOutputTensorsError,
+
+ // Image processing error codes.
+
+ // Unspecified image processing failures.
+ kImageProcessingError = 500,
+ // Unexpected input or output buffer metadata.
+ // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees.
+ kImageProcessingInvalidArgumentError,
+ // Image processing operation failures.
+ // E.g. libyuv rotation failed for an unknown reason.
+ kImageProcessingBackendError,
+};
+
+// Convenience helper to create an `absl::Status` augmented with the
+// fine-grained `tfls_code` attached as payload under the
+// `kTfLiteSupportPayload` type URL key.
+//
+// This should only be used for non-ok codes since otherwise it does nothing
+// more than returning an object identical to an OK status. See `absl::Status`
+// for more details.
+absl::Status CreateStatusWithPayload(
+ absl::StatusCode canonical_code, absl::string_view message,
+ tflite::support::TfLiteSupportStatus tfls_code =
+ tflite::support::TfLiteSupportStatus::kError);
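+
+// Example (an illustrative sketch, not part of the API): the call below
+// yields an INVALID_ARGUMENT status whose payload, stored under the
+// `kTfLiteSupportPayload` key, is the stringified TFLS code "200", which the
+// standard `absl::Status::GetPayload` accessor reads back:
+//
+//   absl::Status status = CreateStatusWithPayload(
+//       absl::StatusCode::kInvalidArgument,
+//       "Invalid metadata schema version",
+//       TfLiteSupportStatus::kMetadataInvalidSchemaVersionError);
+//   // status.code() == absl::StatusCode::kInvalidArgument
+//   // status.GetPayload(kTfLiteSupportPayload) holds absl::Cord("200")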
+
+} // namespace support
+} // namespace tflite
+#endif // TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_
diff --git a/tensorflow_lite_support/cc/port/BUILD b/tensorflow_lite_support/cc/port/BUILD
new file mode 100644
index 00000000..195d5a11
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/BUILD
@@ -0,0 +1,73 @@
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "statusor",
+ hdrs = [
+ "statusor.h",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc/port/default:statusor",
+ ],
+)
+
+cc_library(
+ name = "status_macros",
+ hdrs = [
+ "status_macros.h",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc/port/default:status_macros",
+ ],
+)
+
+cc_library(
+ name = "tflite_wrapper",
+ hdrs = ["tflite_wrapper.h"],
+ deps = ["//tensorflow_lite_support/cc/port/default:tflite_wrapper"],
+)
+
+# This is identical to the rule above, except that it gets built with
+# '-DTFLITE_USE_C_API'. This rule is used for unit tests that verify things
+# work correctly when built with TFLITE_USE_C_API defined.
+cc_library(
+ name = "tflite_wrapper_with_c_api_for_test",
+ testonly = 1,
+ hdrs = ["tflite_wrapper.h"],
+ deps = [
+ "//intelligence/mobile_acceleration/proto:allowlist_portable_proto",
+ "//intelligence/mobile_acceleration/support_library:tflite_wrapper_with_c_api_for_test",
+ ],
+)
+
+cc_library(
+ name = "integral_types",
+ hdrs = ["integral_types.h"],
+)
+
+cc_library(
+ name = "gtest",
+ testonly = 1,
+ hdrs = [
+ "gmock.h",
+ "gtest.h",
+ ],
+ deps = [
+ "//testing/base/public:gunit_for_library_testonly",
+ ],
+)
+
+cc_library(
+ name = "gtest_main",
+ testonly = 1,
+ hdrs = [
+ "benchmark.h",
+ "gmock.h",
+ "gtest.h",
+ ],
+ deps = [
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow_lite_support/cc/port/benchmark.h b/tensorflow_lite_support/cc/port/benchmark.h
new file mode 100644
index 00000000..74bc1a68
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/benchmark.h
@@ -0,0 +1,21 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_
+
+#include "gtest/benchmark.h"
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_
diff --git a/tensorflow_lite_support/cc/port/build_defs.bzl b/tensorflow_lite_support/cc/port/build_defs.bzl
new file mode 100644
index 00000000..a8053db2
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/build_defs.bzl
@@ -0,0 +1,30 @@
+""".bzl file for TFLite Support open source build configs."""
+
+load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library")
+
+def provided_args(**kwargs):
+ """Returns the keyword arguments omitting None arguments."""
+ return {k: v for k, v in kwargs.items() if v != None}
+
+def support_cc_proto_library(name, srcs, visibility = None, deps = [], cc_deps = [], testonly = 0):
+ """Generate cc_proto_library for TFLite Support open source version.
+
+ Args:
+ name: the name of the cc_proto_library.
+ srcs: the .proto files of the cc_proto_library for Bazel use.
+ visibility: visibility of this target.
+      deps: a list of dependency labels; unused in the open source build (see
+        `_ignore` below).
+      cc_deps: a list of cc_proto_library dependency labels.
+      testonly: whether the generated proto library is test-only.
+ """
+ _ignore = [deps]
+ cc_proto_library(**provided_args(
+ name = name,
+ srcs = srcs,
+ visibility = visibility,
+ deps = cc_deps,
+ testonly = testonly,
+ cc_libs = ["@com_google_protobuf//:protobuf"],
+ protoc = "@com_google_protobuf//:protoc",
+ default_runtime = "@com_google_protobuf//:protobuf",
+ alwayslink = 1,
+ ))
diff --git a/tensorflow_lite_support/cc/port/default/BUILD b/tensorflow_lite_support/cc/port/default/BUILD
new file mode 100644
index 00000000..3f6e9e93
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/default/BUILD
@@ -0,0 +1,50 @@
+package(
+ default_visibility = [
+ "//tensorflow_lite_support/cc/port:__pkg__",
+ "//tensorflow_lite_support/cc/test:__pkg__",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "statusor",
+ srcs = ["statusor.cc"],
+ hdrs = [
+ "statusor.h",
+ "statusor_internals.h",
+ ],
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/meta:type_traits",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:variant",
+ "@com_google_absl//absl/utility",
+ "@com_google_glog//:glog",
+ ],
+)
+
+cc_library(
+ name = "status_macros",
+ hdrs = [
+ "status_macros.h",
+ ],
+ deps = [
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/status",
+ ],
+)
+
+cc_library(
+ name = "tflite_wrapper",
+ srcs = ["tflite_wrapper.cc"],
+ hdrs = [
+ "tflite_wrapper.h",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "@com_google_absl//absl/status",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_cc_proto",
+ ],
+)
diff --git a/tensorflow_lite_support/cc/port/default/status_macros.h b/tensorflow_lite_support/cc/port/default/status_macros.h
new file mode 100644
index 00000000..47476c9c
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/default/status_macros.h
@@ -0,0 +1,215 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file is forked from absl.
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_
+
+#include "absl/base/optimization.h"
+#include "absl/status/status.h"
+
+// Evaluates an expression that produces a `absl::Status`. If the status is not
+// ok, returns it from the current function.
+//
+// For example:
+// absl::Status MultiStepFunction() {
+// RETURN_IF_ERROR(Function(args...));
+// RETURN_IF_ERROR(foo.Method(args...));
+// return absl::OkStatus();
+// }
+#define RETURN_IF_ERROR(expr) \
+ STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
+ if (::tflite::support::status_macro_internal::StatusAdaptorForMacros \
+ status_macro_internal_adaptor = {(expr)}) { \
+ } else /* NOLINT */ \
+ return status_macro_internal_adaptor.Consume()
+
+// Executes an expression `rexpr` that returns a `tflite::support::StatusOr<T>`.
+// On OK, moves its value into the variable defined by `lhs`, otherwise returns
+// from the current function. By default the error status is returned
+// unchanged, but it may be modified by an `error_expression`. If there is an
+// error, `lhs` is not evaluated; thus any side effects that `lhs` may have
+// only occur in the success case.
+//
+// Interface:
+//
+// ASSIGN_OR_RETURN(lhs, rexpr)
+// ASSIGN_OR_RETURN(lhs, rexpr, error_expression);
+//
+// WARNING: if lhs is parenthesized, the parentheses are removed. See examples
+// for more details.
+//
+// WARNING: expands into multiple statements; it cannot be used in a single
+// statement (e.g. as the body of an if statement without {})!
+//
+// Example: Declaring and initializing a new variable (ValueType can be anything
+// that can be initialized with assignment, including references):
+// ASSIGN_OR_RETURN(ValueType value, MaybeGetValue(arg));
+//
+// Example: Assigning to an existing variable:
+// ValueType value;
+// ASSIGN_OR_RETURN(value, MaybeGetValue(arg));
+//
+// Example: Assigning to an expression with side effects:
+// MyProto data;
+// ASSIGN_OR_RETURN(*data.mutable_str(), MaybeGetValue(arg));
+// // No field "str" is added on error.
+//
+// Example: Assigning to a std::unique_ptr.
+// ASSIGN_OR_RETURN(std::unique_ptr<T> ptr, MaybeGetPtr(arg));
+//
+// Example: Assigning to a map. Because of a C preprocessor
+// limitation, the type used in ASSIGN_OR_RETURN cannot contain commas, so
+// wrap the lhs in parentheses:
+// ASSIGN_OR_RETURN((absl::flat_hash_map<Foo, Bar> my_map), GetMap());
+// Or use auto if the type is obvious enough:
+// ASSIGN_OR_RETURN(const auto& my_map, GetMapRef());
+//
+// Example: Assigning to structured bindings. The same comma limitation as
+// for maps applies, so wrap the lhs in parentheses.
+// ASSIGN_OR_RETURN((const auto& [first, second]), GetPair());
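+//
+// Example: Combining both macros end to end (an illustrative sketch;
+// `ParsePositive` and `DoSomething` are hypothetical functions, not part of
+// this library):
+//   tflite::support::StatusOr<int> ParsePositive(int raw) {
+//     if (raw <= 0) return absl::InvalidArgumentError("must be positive");
+//     return raw;
+//   }
+//
+//   absl::Status Process(int raw) {
+//     // Early-returns the error status if `ParsePositive` fails.
+//     ASSIGN_OR_RETURN(int value, ParsePositive(raw));
+//     // `DoSomething` is assumed to return an absl::Status.
+//     RETURN_IF_ERROR(DoSomething(value));
+//     return absl::OkStatus();
+//   }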
+
+#define ASSIGN_OR_RETURN(...) \
+ STATUS_MACROS_IMPL_GET_VARIADIC_((__VA_ARGS__, \
+ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \
+ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \
+ (__VA_ARGS__)
+
+// =================================================================
+// == Implementation details, do not rely on anything below here. ==
+// =================================================================
+
+// Some builds do not yet fully support C++14, so use a C++11-compatible
+// constexpr recursion.
+constexpr bool TFLSHasPotentialConditionalOperator(const char* lhs, int index) {
+ return (index == -1
+ ? false
+ : (lhs[index] == '?'
+ ? true
+ : TFLSHasPotentialConditionalOperator(lhs, index - 1)));
+}
+
+// MSVC incorrectly expands variadic macros; splice together a macro call to
+// work around the bug.
+#define STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME
+#define STATUS_MACROS_IMPL_GET_VARIADIC_(args) \
+ STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args
+
+#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \
+ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, _)
+#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \
+ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \
+ STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \
+ error_expression)
+#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \
+ error_expression) \
+ auto statusor = (rexpr); \
+ if (ABSL_PREDICT_FALSE(!statusor.ok())) { \
+ ::absl::Status _(std::move(statusor).status()); \
+ (void)_; /* error_expression is allowed to not use this variable */ \
+ return (error_expression); \
+ } \
+ { \
+ static_assert( \
+ #lhs[0] != '(' || #lhs[sizeof(#lhs) - 2] != ')' || \
+ !TFLSHasPotentialConditionalOperator(#lhs, sizeof(#lhs) - 2), \
+ "Identified potential conditional operator, consider not " \
+ "using ASSIGN_OR_RETURN"); \
+ } \
+ STATUS_MACROS_IMPL_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \
+ std::move(statusor).value()
+
+// Internal helpers for macro expansion.
+#define STATUS_MACROS_IMPL_EAT(...)
+#define STATUS_MACROS_IMPL_REM(...) __VA_ARGS__
+#define STATUS_MACROS_IMPL_EMPTY()
+
+// Internal helpers for checking whether the arguments are empty.
+#define STATUS_MACROS_IMPL_IS_EMPTY_INNER(...) \
+ STATUS_MACROS_IMPL_IS_EMPTY_INNER_I(__VA_ARGS__, 0, 1)
+#define STATUS_MACROS_IMPL_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty
+
+#define STATUS_MACROS_IMPL_IS_EMPTY(...) \
+ STATUS_MACROS_IMPL_IS_EMPTY_I(__VA_ARGS__)
+#define STATUS_MACROS_IMPL_IS_EMPTY_I(...) \
+ STATUS_MACROS_IMPL_IS_EMPTY_INNER(_, ##__VA_ARGS__)
+
+// Internal helpers for if statement.
+#define STATUS_MACROS_IMPL_IF_1(_Then, _Else) _Then
+#define STATUS_MACROS_IMPL_IF_0(_Then, _Else) _Else
+#define STATUS_MACROS_IMPL_IF(_Cond, _Then, _Else) \
+ STATUS_MACROS_IMPL_CONCAT_(STATUS_MACROS_IMPL_IF_, _Cond) \
+ (_Then, _Else)
+
+// Expands to 1 if the input is parenthesized. Otherwise expands to 0.
+#define STATUS_MACROS_IMPL_IS_PARENTHESIZED(...) \
+ STATUS_MACROS_IMPL_IS_EMPTY(STATUS_MACROS_IMPL_EAT __VA_ARGS__)
+
+// If the input is parenthesized, removes the parentheses. Otherwise expands to
+// the input unchanged.
+#define STATUS_MACROS_IMPL_UNPARENTHESIZE_IF_PARENTHESIZED(...) \
+ STATUS_MACROS_IMPL_IF(STATUS_MACROS_IMPL_IS_PARENTHESIZED(__VA_ARGS__), \
+ STATUS_MACROS_IMPL_REM, STATUS_MACROS_IMPL_EMPTY()) \
+ __VA_ARGS__
+
+// Internal helper for concatenating macro values.
+#define STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y
+#define STATUS_MACROS_IMPL_CONCAT_(x, y) STATUS_MACROS_IMPL_CONCAT_INNER_(x, y)
+
+// The GNU compiler emits a warning for code like:
+//
+// if (foo)
+// if (bar) { } else baz;
+//
+// because it thinks you might want the else to bind to the first if. This
+// leads to problems with code like:
+//
+//   if (do_expr) RETURN_IF_ERROR(expr);
+//
+// The "switch (0) case 0:" idiom is used to suppress this.
+#define STATUS_MACROS_IMPL_ELSE_BLOCKER_ \
+ switch (0) \
+ case 0: \
+ default: // NOLINT
+
+namespace tflite {
+namespace support {
+namespace status_macro_internal {
+
+// Provides a conversion to bool so that it can be used inside an if statement
+// that declares a variable.
+class StatusAdaptorForMacros {
+ public:
+ StatusAdaptorForMacros(const ::absl::Status& status) // NOLINT
+ : status_(status) {}
+
+ StatusAdaptorForMacros(::absl::Status&& status) // NOLINT
+ : status_(std::move(status)) {}
+
+ StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete;
+ StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete;
+
+ explicit operator bool() const { return ABSL_PREDICT_TRUE(status_.ok()); }
+
+ ::absl::Status&& Consume() { return std::move(status_); }
+
+ private:
+ ::absl::Status status_;
+};
+
+} // namespace status_macro_internal
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_
diff --git a/tensorflow_lite_support/cc/port/default/statusor.cc b/tensorflow_lite_support/cc/port/default/statusor.cc
new file mode 100644
index 00000000..5cf1196a
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/default/statusor.cc
@@ -0,0 +1,65 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file is forked from absl.
+
+#include "tensorflow_lite_support/cc/port/default/statusor.h"
+
+#include <utility>
+
+#include <glog/logging.h>
+#include "absl/strings/str_cat.h"
+
+namespace tflite {
+namespace support {
+
+BadStatusOrAccess::BadStatusOrAccess(absl::Status status)
+ : status_(std::move(status)) {}
+
+BadStatusOrAccess::~BadStatusOrAccess() = default;
+
+const char* BadStatusOrAccess::what() const noexcept {
+ return "Bad StatusOr access";
+}
+
+const absl::Status& BadStatusOrAccess::status() const { return status_; }
+
+namespace internal_statusor {
+
+void Helper::HandleInvalidStatusCtorArg(absl::Status* status) {
+ const char* kMessage =
+ "An OK status is not a valid constructor argument to StatusOr<T>";
+ LOG(DFATAL) << kMessage;
+  // In optimized builds, we will fall back to an INTERNAL error.
+ *status = absl::InternalError(kMessage);
+}
+
+void Helper::Crash(const absl::Status& status) {
+ LOG(FATAL) << "Attempting to fetch value instead of handling error "
+ << status;
+ _exit(1);
+}
+
+void ThrowBadStatusOrAccess(absl::Status status) {
+#ifdef ABSL_HAVE_EXCEPTIONS
+ throw BadStatusOrAccess(std::move(status));
+#else
+ LOG(FATAL) << "Attempting to fetch value instead of handling error "
+ << status;
+#endif
+}
+
+} // namespace internal_statusor
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/port/default/statusor.h b/tensorflow_lite_support/cc/port/default/statusor.h
new file mode 100644
index 00000000..4273e1ce
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/default/statusor.h
@@ -0,0 +1,574 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file is forked from absl.
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_
+
+#include <exception>
+#include <initializer_list>
+#include <new>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+#include "absl/base/optimization.h"
+#include "absl/meta/type_traits.h"
+#include "absl/status/status.h"
+#include "absl/types/variant.h"
+#include "absl/utility/utility.h"
+#include "tensorflow_lite_support/cc/port/default/statusor_internals.h"
+
+namespace tflite {
+namespace support {
+
+#ifndef SWIG
+class BadStatusOrAccess : public std::exception {
+ public:
+ explicit BadStatusOrAccess(absl::Status status);
+ ~BadStatusOrAccess() override;
+ const char* what() const noexcept override;
+ const absl::Status& status() const;
+
+ private:
+ absl::Status status_;
+};
+#endif // !SWIG
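+
+// With exceptions enabled (ABSL_HAVE_EXCEPTIONS), a failed `value()` access
+// surfaces as a `BadStatusOrAccess` and can be caught as follows (an
+// illustrative sketch):
+//
+//   StatusOr<int> result = absl::NotFoundError("no value");
+//   try {
+//     int value = result.value();  // Throws: `result` holds an error.
+//   } catch (const BadStatusOrAccess& e) {
+//     LOG(ERROR) << e.what() << ": " << e.status();
+//   }
+//
+// Without exceptions, the same access `LOG(FATAL)`s instead.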
+
+// Returned StatusOr objects may not be ignored.
+// Note: Disabled for SWIG as it doesn't parse attributes correctly. Codesearch
+// doesn't handle ifdefs as part of a class definition (b/6995610), so we use a
+// forward declaration.
+#ifndef SWIG
+template <typename T>
+class ABSL_MUST_USE_RESULT StatusOr;
+#endif
+
+template <typename T>
+class StatusOr : private internal_statusor::StatusOrData<T>,
+ private internal_statusor::CopyCtorBase<T>,
+ private internal_statusor::MoveCtorBase<T>,
+ private internal_statusor::CopyAssignBase<T>,
+ private internal_statusor::MoveAssignBase<T> {
+ template <typename U>
+ friend class StatusOr;
+
+ typedef internal_statusor::StatusOrData<T> Base;
+
+ public:
+ typedef T value_type;
+
+ // Constructs a new StatusOr with Status::UNKNOWN status. This is marked
+ // 'explicit' to try to catch cases like 'return {};', where people think
+ // tflite::support::StatusOr<std::vector<int>> will be initialized with an
+ // empty vector, instead of a Status::UNKNOWN status.
+ explicit StatusOr();
+
+ // StatusOr<T> is copy constructible if T is copy constructible.
+ StatusOr(const StatusOr&) = default;
+ // StatusOr<T> is copy assignable if T is copy constructible and copy
+ // assignable.
+ StatusOr& operator=(const StatusOr&) = default;
+
+#ifndef SWIG
+
+ // StatusOr<T> is move constructible if T is move constructible.
+ StatusOr(StatusOr&&) = default;
+  // StatusOr<T> is move assignable if T is move constructible and move
+  // assignable.
+ StatusOr& operator=(StatusOr&&) = default;
+
+ // Converting constructors from StatusOr<U>, when T is constructible from U.
+ // To avoid ambiguity, they are disabled if T is also constructible from
+ // StatusOr<U>. Explicit iff the corresponding construction of T from U is
+ // explicit.
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ std::is_convertible<const U&, T>,
+ absl::negation<
+ internal_statusor::IsConstructibleOrConvertibleFromStatusOr<
+ T, U>>>::value,
+ int> = 0>
+ StatusOr(const StatusOr<U>& other) // NOLINT
+ : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ absl::negation<std::is_convertible<const U&, T>>,
+ absl::negation<
+ internal_statusor::IsConstructibleOrConvertibleFromStatusOr<
+ T, U>>>::value,
+ int> = 0>
+ explicit StatusOr(const StatusOr<U>& other)
+ : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
+
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
+ std::is_convertible<U&&, T>,
+ absl::negation<
+ internal_statusor::IsConstructibleOrConvertibleFromStatusOr<
+ T, U>>>::value,
+ int> = 0>
+ StatusOr(StatusOr<U>&& other) // NOLINT
+ : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
+ absl::negation<std::is_convertible<U&&, T>>,
+ absl::negation<
+ internal_statusor::IsConstructibleOrConvertibleFromStatusOr<
+ T, U>>>::value,
+ int> = 0>
+ explicit StatusOr(StatusOr<U>&& other)
+ : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
+
+ // Conversion copy/move assignment operator, T must be constructible and
+ // assignable from U. Only enable if T cannot be directly assigned from
+ // StatusOr<U>.
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>,
+ std::is_constructible<T, const U&>,
+ std::is_assignable<T, const U&>,
+ absl::negation<
+ internal_statusor::
+ IsConstructibleOrConvertibleOrAssignableFromStatusOr<
+ T, U>>>::value,
+ int> = 0>
+ StatusOr& operator=(const StatusOr<U>& other) {
+ this->Assign(other);
+ return *this;
+ }
+ template <
+ typename U,
+ absl::enable_if_t<
+ absl::conjunction<
+ absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>,
+ std::is_assignable<T, U&&>,
+ absl::negation<
+ internal_statusor::
+ IsConstructibleOrConvertibleOrAssignableFromStatusOr<
+ T, U>>>::value,
+ int> = 0>
+ StatusOr& operator=(StatusOr<U>&& other) {
+ this->Assign(std::move(other));
+ return *this;
+ }
+
+#endif // SWIG
+
+ // Constructs a new StatusOr with the given value. After calling this
+ // constructor, this->ok() will be true and the contained value may be
+ // retrieved with value(), operator*(), or operator->().
+ //
+ // NOTE: Not explicit - we want to use StatusOr<T> as a return type
+ // so it is convenient and sensible to be able to do 'return T()'
+ // when the return type is StatusOr<T>.
+ //
+ // REQUIRES: T is copy constructible.
+ // TODO(b/113125838): Replace this constructor with a direct-initialization
+ // constructor.
+ StatusOr(const T& value);
+
+ // Constructs a new StatusOr with the given non-ok status. After calling this
+ // constructor, this->ok() will be false and calls to value() will CHECK-fail.
+ //
+ // NOTE: Not explicit - we want to use StatusOr<T> as a return
+ // value, so it is convenient and sensible to be able to do 'return
+ // Status()' when the return type is StatusOr<T>.
+ //
+ // REQUIRES: !status.ok(). This requirement is DCHECKed.
+  // In optimized builds, passing absl::OkStatus() here will have the effect
+  // of falling back to an INTERNAL error.
+ StatusOr(const absl::Status& status);
+ StatusOr& operator=(const absl::Status& status);
+
+#ifndef SWIG
+ // Perfect-forwarding value assignment operator.
+ // If `*this` contains a `T` value before the call, the contained value is
+  // assigned from `std::forward<U>(v)`; otherwise, it is directly-initialized
+  // from `std::forward<U>(v)`.
+  // This function does not participate in overload resolution unless:
+ // 1. `std::is_constructible_v<T, U>` is true,
+ // 2. `std::is_assignable_v<T&, U>` is true.
+ // 3. `std::is_same_v<StatusOr<T>, std::remove_cvref_t<U>>` is false.
+ // 4. Assigning `U` to `T` is not ambiguous:
+ // If `U` is `StatusOr<V>` and `T` is constructible and assignable from
+ // both `StatusOr<V>` and `V`, the assignment is considered bug-prone and
+ // ambiguous thus will fail to compile. For example:
+ // StatusOr<bool> s1 = true; // s1.ok() && *s1 == true
+ // StatusOr<bool> s2 = false; // s2.ok() && *s2 == false
+ // s1 = s2; // ambiguous, `s1 = *s2` or `s1 = bool(s2)`?
+ template <
+ typename U = T,
+ typename = typename std::enable_if<absl::conjunction<
+ std::is_constructible<T, U&&>, std::is_assignable<T&, U&&>,
+ internal_statusor::IsForwardingAssignmentValid<T, U&&>>::value>::type>
+ StatusOr& operator=(U&& v) {
+ this->Assign(std::forward<U>(v));
+ return *this;
+ }
+
+ // Similar to the `const T&` overload.
+ //
+ // REQUIRES: T is move constructible.
+ StatusOr(T&& value);
+
+ // RValue versions of the operations declared above.
+ StatusOr(absl::Status&& status);
+ StatusOr& operator=(absl::Status&& status);
+
+ // Constructs the inner value T in-place using the provided args, using the
+ // T(args...) constructor.
+ template <typename... Args>
+ explicit StatusOr(absl::in_place_t, Args&&... args);
+ template <typename U, typename... Args>
+ explicit StatusOr(absl::in_place_t, std::initializer_list<U> ilist,
+ Args&&... args);
+
+ // Constructs the inner value T in-place using the provided args, using the
+ // T(U) (direct-initialization) constructor. Only valid if T can be
+ // constructed from a U. Can accept move or copy constructors. Explicit if
+ // U is not convertible to T. To avoid ambiguity, this is disabled if U is
+ // a StatusOr<J>, where J is convertible to T.
+ // Style waiver for implicit conversion granted in cl/209187539.
+ template <typename U = T,
+ absl::enable_if_t<
+ absl::conjunction<
+ internal_statusor::IsDirectInitializationValid<T, U&&>,
+ std::is_constructible<T, U&&>,
+ std::is_convertible<U&&, T>>::value,
+ int> = 0>
+ StatusOr(U&& u) // NOLINT
+ : StatusOr(absl::in_place, std::forward<U>(u)) {}
+
+ template <typename U = T,
+ absl::enable_if_t<
+ absl::conjunction<
+ internal_statusor::IsDirectInitializationValid<T, U&&>,
+ std::is_constructible<T, U&&>,
+ absl::negation<std::is_convertible<U&&, T>>>::value,
+ int> = 0>
+ explicit StatusOr(U&& u) // NOLINT
+ : StatusOr(absl::in_place, std::forward<U>(u)) {}
+
+#endif // SWIG
+
+ // Returns this->status().ok()
+ ABSL_MUST_USE_RESULT bool ok() const { return this->status_.ok(); }
+
+  // Returns a reference to our status. If this contains a T, then
+  // returns absl::OkStatus().
+#ifdef SWIG
+ const ::util::Status& status() const;
+#else // SWIG
+ const absl::Status& status() const&;
+ absl::Status status() &&;
+#endif // SWIG
+
+  // Returns a reference to the held value if `this->ok()`. Otherwise, throws
+  // `BadStatusOrAccess` if exceptions are enabled, or `LOG(FATAL)`s if
+  // exceptions are disabled.
+ // If you have already checked the status using `this->ok()` or
+ // `operator bool()`, you probably want to use `operator*()` or `operator->()`
+ // to access the value instead of `value`.
+ // Note: for value types that are cheap to copy, prefer simple code:
+ //
+ // T value = statusor.value();
+ //
+ // Otherwise, if the value type is expensive to copy, but can be left
+ // in the StatusOr, simply assign to a reference:
+ //
+ // T& value = statusor.value(); // or `const T&`
+ //
+ // Otherwise, if the value type supports an efficient move, it can be
+ // used as follows:
+ //
+ // T value = std::move(statusor).value();
+ //
+ // The `std::move` on statusor instead of on the whole expression enables
+ // warnings about possible uses of the statusor object after the move.
+#ifdef SWIG
+ const T& value() const;
+#else // SWIG
+ const T& value() const&;
+ T& value() &;
+ const T&& value() const&&;
+ T&& value() &&;
+#endif // SWIG
+
+#ifndef SWIG
+ // Returns a reference to the current value.
+ //
+ // REQUIRES: this->ok() == true, otherwise the behavior is undefined.
+ //
+ // Use this->ok() or `operator bool()` to verify that there is a current
+ // value. Alternatively, see value() for a similar API that guarantees
+ // CHECK-failing if there is no current value.
+ const T& operator*() const&;
+ T& operator*() &;
+ const T&& operator*() const&&;
+ T&& operator*() &&;
+#endif // SWIG
+
+#ifndef SWIG
+ // Returns a pointer to the current value.
+ //
+ // REQUIRES: this->ok() == true, otherwise the behavior is undefined.
+ //
+ // Use this->ok() or `operator bool()` to verify that there is a current
+ // value.
+ const T* operator->() const;
+ T* operator->();
+#endif // SWIG
+
+#ifndef SWIG
+ // Returns a copy of the current value if this->ok() == true. Otherwise
+ // returns a default value.
+ template <typename U>
+ T value_or(U&& default_value) const&;
+ template <typename U>
+ T value_or(U&& default_value) &&;
+#endif // SWIG
+
+ // Ignores any errors. This method does nothing except potentially suppress
+ // complaints from any tools that are checking that errors are not dropped on
+ // the floor.
+ void IgnoreError() const;
+
+#ifndef SWIG
+ // Reconstructs the inner value T in-place using the provided args, using the
+ // T(args...) constructor. Returns reference to the reconstructed `T`.
+ template <typename... Args>
+ T& emplace(Args&&... args) {
+ if (ok()) {
+ this->Clear();
+ this->MakeValue(std::forward<Args>(args)...);
+ } else {
+ this->MakeValue(std::forward<Args>(args)...);
+ this->status_ = absl::OkStatus();
+ }
+ return this->data_;
+ }
+
+ template <
+ typename U, typename... Args,
+ absl::enable_if_t<
+ std::is_constructible<T, std::initializer_list<U>&, Args&&...>::value,
+ int> = 0>
+ T& emplace(std::initializer_list<U> ilist, Args&&... args) {
+ if (ok()) {
+ this->Clear();
+ this->MakeValue(ilist, std::forward<Args>(args)...);
+ } else {
+ this->MakeValue(ilist, std::forward<Args>(args)...);
+ this->status_ = absl::OkStatus();
+ }
+ return this->data_;
+ }
+#endif // SWIG
+
+ private:
+#ifndef SWIG
+ using internal_statusor::StatusOrData<T>::Assign;
+ template <typename U>
+ void Assign(const StatusOr<U>& other);
+ template <typename U>
+ void Assign(StatusOr<U>&& other);
+#endif // SWIG
+};
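+
+// Typical usage (an illustrative sketch; `ParseInt` is a hypothetical
+// function returning `StatusOr<int>`):
+//
+//   StatusOr<int> parsed = ParseInt("42");
+//   if (parsed.ok()) {
+//     int value = *parsed;  // Unchecked access; valid after testing ok().
+//   }
+//   int with_default = ParseInt("oops").value_or(0);  // Falls back to 0.
+//   int checked = ParseInt("7").value();  // Throws/LOG(FATAL)s on error.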
+
+#ifndef SWIG
+////////////////////////////////////////////////////////////////////////////////
+// Implementation details for StatusOr<T>
+
+template <typename T>
+tflite::support::StatusOr<T>::StatusOr()
+ : Base(absl::Status(absl::StatusCode::kUnknown, "")) {}
+
+template <typename T>
+tflite::support::StatusOr<T>::StatusOr(const T& value) : Base(value) {}
+
+template <typename T>
+tflite::support::StatusOr<T>::StatusOr(const absl::Status& status)
+ : Base(status) {}
+
+template <typename T>
+tflite::support::StatusOr<T>& StatusOr<T>::operator=(
+ const absl::Status& status) {
+ this->Assign(status);
+ return *this;
+}
+
+template <typename T>
+tflite::support::StatusOr<T>::StatusOr(T&& value) : Base(std::move(value)) {}
+
+template <typename T>
+tflite::support::StatusOr<T>::StatusOr(absl::Status&& status)
+ : Base(std::move(status)) {}
+
+template <typename T>
+tflite::support::StatusOr<T>& StatusOr<T>::operator=(absl::Status&& status) {
+ this->Assign(std::move(status));
+ return *this;
+}
+
+template <typename T>
+template <typename U>
+inline void StatusOr<T>::Assign(const StatusOr<U>& other) {
+ if (other.ok()) {
+ this->Assign(other.value());
+ } else {
+ this->Assign(other.status());
+ }
+}
+
+template <typename T>
+template <typename U>
+inline void StatusOr<T>::Assign(StatusOr<U>&& other) {
+ if (other.ok()) {
+ this->Assign(std::move(other).value());
+ } else {
+ this->Assign(std::move(other).status());
+ }
+}
+template <typename T>
+template <typename... Args>
+tflite::support::StatusOr<T>::StatusOr(absl::in_place_t, Args&&... args)
+ : Base(absl::in_place, std::forward<Args>(args)...) {}
+
+template <typename T>
+template <typename U, typename... Args>
+tflite::support::StatusOr<T>::StatusOr(absl::in_place_t,
+ std::initializer_list<U> ilist,
+ Args&&... args)
+ : Base(absl::in_place, ilist, std::forward<Args>(args)...) {}
+
+template <typename T>
+const absl::Status& StatusOr<T>::status() const& {
+ return this->status_;
+}
+template <typename T>
+absl::Status StatusOr<T>::status() && {
+ return ok() ? absl::OkStatus() : std::move(this->status_);
+}
+
+template <typename T>
+const T& StatusOr<T>::value() const& {
+ if (!this->ok()) internal_statusor::ThrowBadStatusOrAccess(this->status_);
+ return this->data_;
+}
+
+template <typename T>
+T& StatusOr<T>::value() & {
+ if (!this->ok()) internal_statusor::ThrowBadStatusOrAccess(this->status_);
+ return this->data_;
+}
+
+template <typename T>
+const T&& StatusOr<T>::value() const&& {
+ if (!this->ok()) {
+ internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_));
+ }
+ return std::move(this->data_);
+}
+
+template <typename T>
+T&& StatusOr<T>::value() && {
+ if (!this->ok()) {
+ internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_));
+ }
+ return std::move(this->data_);
+}
+
+template <typename T>
+const T& StatusOr<T>::operator*() const& {
+ this->EnsureOk();
+ return this->data_;
+}
+
+template <typename T>
+T& StatusOr<T>::operator*() & {
+ this->EnsureOk();
+ return this->data_;
+}
+
+template <typename T>
+const T&& StatusOr<T>::operator*() const&& {
+ this->EnsureOk();
+ return std::move(this->data_);
+}
+
+template <typename T>
+T&& StatusOr<T>::operator*() && {
+ this->EnsureOk();
+ return std::move(this->data_);
+}
+
+template <typename T>
+const T* StatusOr<T>::operator->() const {
+ this->EnsureOk();
+ return &this->data_;
+}
+
+template <typename T>
+T* StatusOr<T>::operator->() {
+ this->EnsureOk();
+ return &this->data_;
+}
+
+template <typename T>
+template <typename U>
+T StatusOr<T>::value_or(U&& default_value) const& {
+ if (ok()) {
+ return this->data_;
+ }
+ return std::forward<U>(default_value);
+}
+
+template <typename T>
+template <typename U>
+T StatusOr<T>::value_or(U&& default_value) && {
+ if (ok()) {
+ return std::move(this->data_);
+ }
+ return std::forward<U>(default_value);
+}
+
+template <typename T>
+void StatusOr<T>::IgnoreError() const {
+ // no-op
+}
+
+#endif // SWIG
+
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_
diff --git a/tensorflow_lite_support/cc/port/default/statusor_internals.h b/tensorflow_lite_support/cc/port/default/statusor_internals.h
new file mode 100644
index 00000000..56d46616
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/default/statusor_internals.h
@@ -0,0 +1,409 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file is forked from absl.
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_INTERNALS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_INTERNALS_H_
+
+#include <type_traits>
+#include <utility>
+
+#include "absl/meta/type_traits.h"
+#include "absl/status/status.h"
+#include "absl/utility/utility.h"
+
+namespace tflite {
+namespace support {
+
+template <typename T>
+class ABSL_MUST_USE_RESULT StatusOr;
+
+namespace internal_statusor {
+
+// Detects whether `T` is constructible or convertible from `StatusOr<U>`.
+template <typename T, typename U>
+using IsConstructibleOrConvertibleFromStatusOr =
+ absl::disjunction<std::is_constructible<T, StatusOr<U>&>,
+ std::is_constructible<T, const StatusOr<U>&>,
+ std::is_constructible<T, StatusOr<U>&&>,
+ std::is_constructible<T, const StatusOr<U>&&>,
+ std::is_convertible<StatusOr<U>&, T>,
+ std::is_convertible<const StatusOr<U>&, T>,
+ std::is_convertible<StatusOr<U>&&, T>,
+ std::is_convertible<const StatusOr<U>&&, T>>;
+
+// Detects whether `T` is constructible or convertible or assignable from
+// `StatusOr<U>`.
+template <typename T, typename U>
+using IsConstructibleOrConvertibleOrAssignableFromStatusOr =
+ absl::disjunction<IsConstructibleOrConvertibleFromStatusOr<T, U>,
+ std::is_assignable<T&, StatusOr<U>&>,
+ std::is_assignable<T&, const StatusOr<U>&>,
+ std::is_assignable<T&, StatusOr<U>&&>,
+ std::is_assignable<T&, const StatusOr<U>&&>>;
+
+// Detects whether direct initializing `StatusOr<T>` from `U` is ambiguous, i.e.
+// when `U` is `StatusOr<V>` and `T` is constructible or convertible from `V`.
+template <typename T, typename U>
+struct IsDirectInitializationAmbiguous
+ : public absl::conditional_t<
+ std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>,
+ U>::value,
+ std::false_type,
+ IsDirectInitializationAmbiguous<
+ T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
+
+template <typename T, typename V>
+struct IsDirectInitializationAmbiguous<T, tflite::support::StatusOr<V>>
+ : public IsConstructibleOrConvertibleFromStatusOr<T, V> {};
+
+// Checks against the constraints of direct initialization, i.e. when
+// `StatusOr<T>::StatusOr(U&&)` should participate in overload resolution.
+template <typename T, typename U>
+using IsDirectInitializationValid = absl::disjunction<
+ // Short circuits if T is basically U.
+ std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ absl::negation<absl::disjunction<
+ std::is_same<tflite::support::StatusOr<T>,
+ absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<absl::Status,
+ absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<absl::in_place_t,
+ absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ IsDirectInitializationAmbiguous<T, U>>>>;
+
+// This trait detects whether `StatusOr<T>::operator=(U&&)` is ambiguous, which
+// is equivalent to whether all the following conditions are met:
+// 1. `U` is `StatusOr<V>`.
+// 2. `T` is constructible and assignable from `V`.
+// 3. `T` is constructible and assignable from `U` (i.e. `StatusOr<V>`).
+// For example, the following code is considered ambiguous:
+// (`T` is `bool`, `U` is `StatusOr<bool>`, `V` is `bool`)
+//   StatusOr<bool> s1 = true;  // s1.ok() && s1.value() == true
+//   StatusOr<bool> s2 = false;  // s2.ok() && s2.value() == false
+//   s1 = s2;  // ambiguous, `s1 = s2.value()` or `s1 = bool(s2)`?
+template <typename T, typename U>
+struct IsForwardingAssignmentAmbiguous
+ : public absl::conditional_t<
+ std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>,
+ U>::value,
+ std::false_type,
+ IsForwardingAssignmentAmbiguous<
+ T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
+
+template <typename T, typename U>
+struct IsForwardingAssignmentAmbiguous<T, tflite::support::StatusOr<U>>
+ : public IsConstructibleOrConvertibleOrAssignableFromStatusOr<T, U> {};
+
+// Checks against the constraints of the forwarding assignment, i.e. whether
+// `StatusOr<T>::operator=(U&&)` should participate in overload resolution.
+template <typename T, typename U>
+using IsForwardingAssignmentValid = absl::disjunction<
+ // Short circuits if T is basically U.
+ std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ absl::negation<absl::disjunction<
+ std::is_same<tflite::support::StatusOr<T>,
+ absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<absl::Status,
+ absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ std::is_same<absl::in_place_t,
+ absl::remove_cv_t<absl::remove_reference_t<U>>>,
+ IsForwardingAssignmentAmbiguous<T, U>>>>;
+
+class Helper {
+ public:
+ // Move type-agnostic error handling to the .cc.
+ static void HandleInvalidStatusCtorArg(absl::Status*);
+ ABSL_ATTRIBUTE_NORETURN static void Crash(const absl::Status& status);
+};
+
+// Construct an instance of T in `p` through placement new, passing Args... to
+// the constructor.
+// This abstraction is here mostly for the gcc performance fix.
+template <typename T, typename... Args>
+void PlacementNew(void* p, Args&&... args) {
+#if defined(__GNUC__) && !defined(__clang__)
+ // Teach gcc that 'p' cannot be null, fixing code size issues.
+ if (p == nullptr) __builtin_unreachable();
+#endif
+ new (p) T(std::forward<Args>(args)...);
+}
+
+// Helper base class to hold the data and all operations.
+// We move all this to a base class to allow mixing with the appropriate
+// TraitsBase specialization.
+template <typename T>
+class StatusOrData {
+ template <typename U>
+ friend class StatusOrData;
+
+ public:
+ StatusOrData() = delete;
+
+ StatusOrData(const StatusOrData& other) {
+ if (other.ok()) {
+ MakeValue(other.data_);
+ MakeStatus();
+ } else {
+ MakeStatus(other.status_);
+ }
+ }
+
+ StatusOrData(StatusOrData&& other) noexcept {
+ if (other.ok()) {
+ MakeValue(std::move(other.data_));
+ MakeStatus();
+ } else {
+ MakeStatus(std::move(other.status_));
+ }
+ }
+
+ template <typename U>
+ explicit StatusOrData(const StatusOrData<U>& other) {
+ if (other.ok()) {
+ MakeValue(other.data_);
+ MakeStatus();
+ } else {
+ MakeStatus(other.status_);
+ }
+ }
+
+ template <typename U>
+ explicit StatusOrData(StatusOrData<U>&& other) {
+ if (other.ok()) {
+ MakeValue(std::move(other.data_));
+ MakeStatus();
+ } else {
+ MakeStatus(std::move(other.status_));
+ }
+ }
+
+ template <typename... Args>
+ explicit StatusOrData(absl::in_place_t, Args&&... args)
+ : data_(std::forward<Args>(args)...) {
+ MakeStatus();
+ }
+
+ explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); }
+ explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); }
+
+ explicit StatusOrData(const absl::Status& status) : status_(status) {
+ EnsureNotOk();
+ }
+ explicit StatusOrData(absl::Status&& status) : status_(std::move(status)) {
+ EnsureNotOk();
+ }
+
+ StatusOrData& operator=(const StatusOrData& other) {
+ if (this == &other) return *this;
+ if (other.ok())
+ Assign(other.data_);
+ else
+ Assign(other.status_);
+ return *this;
+ }
+
+ StatusOrData& operator=(StatusOrData&& other) {
+ if (this == &other) return *this;
+ if (other.ok())
+ Assign(std::move(other.data_));
+ else
+ Assign(std::move(other.status_));
+ return *this;
+ }
+
+ ~StatusOrData() {
+ if (ok()) {
+ status_.~Status();
+ data_.~T();
+ } else {
+ status_.~Status();
+ }
+ }
+
+ // TODO(b/140189837): Remove the SFINAE condition after cleanup.
+ template <typename U,
+ absl::enable_if_t<std::is_assignable<T&, U&&>::value, int> = 0>
+ void Assign(U&& value) {
+ if (ok()) {
+ data_ = std::forward<U>(value);
+ } else {
+ MakeValue(std::forward<U>(value));
+ status_ = absl::OkStatus();
+ }
+ }
+
+ // TODO(b/140189837): Remove this after cleanup.
+ // This overload is to handle the case where `T` is a `const` type.
+ // `StatusOr` supports assignment for `const` types though it's forbidden by
+ // other standard types like `std::optional`.
+ template <typename U,
+ absl::enable_if_t<!std::is_assignable<T&, U&&>::value, int> = 0>
+ void Assign(U&& value) {
+ if (ok()) {
+ data_.~T();
+ MakeValue(std::forward<U>(value));
+ } else {
+ MakeValue(std::forward<U>(value));
+ status_ = absl::OkStatus();
+ }
+ }
+
+ void Assign(const absl::Status& status) {
+ Clear();
+ status_ = status;
+ EnsureNotOk();
+ }
+
+ void Assign(absl::Status&& status) {
+ Clear();
+ status_ = std::move(status);
+ EnsureNotOk();
+ }
+
+ bool ok() const { return status_.ok(); }
+
+ protected:
+ // status_ will always be active after the constructor.
+ // We make it a union to be able to initialize exactly how we need without
+ // waste.
+  // E.g. in the copy constructor we use the default constructor of Status in
+ // the ok() path to avoid an extra Ref call.
+ union {
+ absl::Status status_;
+ };
+
+ // data_ is active iff status_.ok()==true
+ struct Dummy {};
+ union {
+ // When T is const, we need some non-const object we can cast to void* for
+ // the placement new. dummy_ is that object.
+ Dummy dummy_;
+ T data_;
+ };
+
+ void Clear() {
+ if (ok()) data_.~T();
+ }
+
+ void EnsureOk() const {
+ if (ABSL_PREDICT_FALSE(!ok())) Helper::Crash(status_);
+ }
+
+ void EnsureNotOk() {
+ if (ABSL_PREDICT_FALSE(ok())) Helper::HandleInvalidStatusCtorArg(&status_);
+ }
+
+  // Construct the value (i.e. data_) through placement new with the passed
+ // argument.
+ template <typename... Arg>
+ void MakeValue(Arg&&... arg) {
+ internal_statusor::PlacementNew<T>(&dummy_, std::forward<Arg>(arg)...);
+ }
+
+  // Construct the status (i.e. status_) through placement new with the passed
+ // argument.
+ template <typename... Args>
+ void MakeStatus(Args&&... args) {
+ internal_statusor::PlacementNew<absl::Status>(&status_,
+ std::forward<Args>(args)...);
+ }
+};
+
+// Helper base classes to allow implicitly deleted constructors and assignment
+// operators in `StatusOr`. For example, `CopyCtorBase` will explicitly delete
+// the copy constructor when T is not copy constructible and `StatusOr` will
+// inherit that behavior implicitly.
+template <typename T, bool = std::is_copy_constructible<T>::value>
+struct CopyCtorBase {
+ CopyCtorBase() = default;
+ CopyCtorBase(const CopyCtorBase&) = default;
+ CopyCtorBase(CopyCtorBase&&) = default;
+ CopyCtorBase& operator=(const CopyCtorBase&) = default;
+ CopyCtorBase& operator=(CopyCtorBase&&) = default;
+};
+
+template <typename T>
+struct CopyCtorBase<T, false> {
+ CopyCtorBase() = default;
+ CopyCtorBase(const CopyCtorBase&) = delete;
+ CopyCtorBase(CopyCtorBase&&) = default;
+ CopyCtorBase& operator=(const CopyCtorBase&) = default;
+ CopyCtorBase& operator=(CopyCtorBase&&) = default;
+};
+
+template <typename T, bool = std::is_move_constructible<T>::value>
+struct MoveCtorBase {
+ MoveCtorBase() = default;
+ MoveCtorBase(const MoveCtorBase&) = default;
+ MoveCtorBase(MoveCtorBase&&) = default;
+ MoveCtorBase& operator=(const MoveCtorBase&) = default;
+ MoveCtorBase& operator=(MoveCtorBase&&) = default;
+};
+
+template <typename T>
+struct MoveCtorBase<T, false> {
+ MoveCtorBase() = default;
+ MoveCtorBase(const MoveCtorBase&) = default;
+ MoveCtorBase(MoveCtorBase&&) = delete;
+ MoveCtorBase& operator=(const MoveCtorBase&) = default;
+ MoveCtorBase& operator=(MoveCtorBase&&) = default;
+};
+
+template <typename T, bool = std::is_copy_constructible<T>::value&&
+ std::is_copy_assignable<T>::value>
+struct CopyAssignBase {
+ CopyAssignBase() = default;
+ CopyAssignBase(const CopyAssignBase&) = default;
+ CopyAssignBase(CopyAssignBase&&) = default;
+ CopyAssignBase& operator=(const CopyAssignBase&) = default;
+ CopyAssignBase& operator=(CopyAssignBase&&) = default;
+};
+
+template <typename T>
+struct CopyAssignBase<T, false> {
+ CopyAssignBase() = default;
+ CopyAssignBase(const CopyAssignBase&) = default;
+ CopyAssignBase(CopyAssignBase&&) = default;
+ CopyAssignBase& operator=(const CopyAssignBase&) = delete;
+ CopyAssignBase& operator=(CopyAssignBase&&) = default;
+};
+
+template <typename T, bool = std::is_move_constructible<T>::value&&
+ std::is_move_assignable<T>::value>
+struct MoveAssignBase {
+ MoveAssignBase() = default;
+ MoveAssignBase(const MoveAssignBase&) = default;
+ MoveAssignBase(MoveAssignBase&&) = default;
+ MoveAssignBase& operator=(const MoveAssignBase&) = default;
+ MoveAssignBase& operator=(MoveAssignBase&&) = default;
+};
+
+template <typename T>
+struct MoveAssignBase<T, false> {
+ MoveAssignBase() = default;
+ MoveAssignBase(const MoveAssignBase&) = default;
+ MoveAssignBase(MoveAssignBase&&) = default;
+ MoveAssignBase& operator=(const MoveAssignBase&) = default;
+ MoveAssignBase& operator=(MoveAssignBase&&) = delete;
+};
+
+void ThrowBadStatusOrAccess(absl::Status status);
+
+} // namespace internal_statusor
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_INTERNALS_H_
diff --git a/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
new file mode 100644
index 00000000..548e679a
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
@@ -0,0 +1,60 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h"
+
+#include "absl/status/status.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+
+namespace tflite {
+namespace support {
+
+absl::Status TfLiteInterpreterWrapper::InitializeWithFallback(
+ std::function<absl::Status(std::unique_ptr<tflite::Interpreter>*)>
+ interpreter_initializer,
+ const tflite::proto::ComputeSettings& compute_settings) {
+ if (compute_settings.has_preference() ||
+ compute_settings.has_tflite_settings()) {
+ return absl::UnimplementedError(
+ "Acceleration via ComputeSettings is not supported yet.");
+ }
+ RETURN_IF_ERROR(interpreter_initializer(&interpreter_));
+ return interpreter_->AllocateTensors() != kTfLiteOk
+ ? absl::InternalError(
+ "TFLite interpreter: AllocateTensors() failed.")
+ : absl::OkStatus();
+}
+
+absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
+ const std::function<absl::Status(tflite::Interpreter* interpreter)>&
+ set_inputs) {
+ RETURN_IF_ERROR(set_inputs(interpreter_.get()));
+ return interpreter_->Invoke() != kTfLiteOk
+ ? absl::InternalError("TFLite interpreter: Invoke() failed.")
+ : absl::OkStatus();
+}
+
+absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() {
+ return interpreter_->Invoke() != kTfLiteOk
+ ? absl::InternalError("TFLite interpreter: Invoke() failed.")
+ : absl::OkStatus();
+}
+
+void TfLiteInterpreterWrapper::Cancel() {
+ // NOP
+}
+
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/port/default/tflite_wrapper.h b/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
new file mode 100644
index 00000000..3fd489f7
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/default/tflite_wrapper.h
@@ -0,0 +1,82 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_
+
+#include <functional>
+#include <memory>
+#include <utility>
+
+#include "absl/status/status.h"
+#include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h"
+#include "tensorflow/lite/interpreter.h"
+
+namespace tflite {
+namespace support {
+
+// Wrapper for a TfLiteInterpreter that may be accelerated[1]. This is NOT yet
+// implemented: this class only provides a first, minimal interface in the
+// meantime.
+//
+// [1] See tensorflow/lite/experimental/acceleration for more details.
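+//
+// A hedged usage sketch (assuming a `model` FlatBufferModel and a `resolver`
+// op resolver are already available; not part of the original interface):
+//
+//   TfLiteInterpreterWrapper wrapper;
+//   absl::Status status = wrapper.InitializeWithFallback(
+//       [&](std::unique_ptr<tflite::Interpreter>* interpreter) {
+//         return tflite::InterpreterBuilder(*model, resolver)(interpreter) ==
+//                        kTfLiteOk
+//                    ? absl::OkStatus()
+//                    : absl::InternalError("InterpreterBuilder failed.");
+//       },
+//       tflite::proto::ComputeSettings());
+//   if (status.ok()) status = wrapper.InvokeWithoutFallback();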
+class TfLiteInterpreterWrapper {
+ public:
+ TfLiteInterpreterWrapper() = default;
+
+ virtual ~TfLiteInterpreterWrapper() = default;
+
+ // Calls `interpreter_initializer` and then `AllocateTensors`. Future
+ // implementation of this method will attempt to apply the provided
+ // `compute_settings` with a graceful fallback in case a failure occurs.
+  // Note: until this is implemented, do NOT call this method with a non-empty
+  // `compute_settings`, otherwise an `UnimplementedError` is returned.
+ absl::Status InitializeWithFallback(
+ std::function<absl::Status(std::unique_ptr<tflite::Interpreter>*)>
+ interpreter_initializer,
+ const tflite::proto::ComputeSettings& compute_settings);
+
+ // Calls `set_inputs` and then Invoke() on the interpreter. Future
+ // implementation of this method will perform a graceful fallback in case a
+  // failure occurs due to the `compute_settings` provided at initialization
+ // time.
+ absl::Status InvokeWithFallback(
+ const std::function<absl::Status(tflite::Interpreter* interpreter)>&
+ set_inputs);
+
+  // Calls Invoke() on the interpreter. The caller must have set up the inputs
+  // beforehand.
+ absl::Status InvokeWithoutFallback();
+
+  // Cancels the currently running TFLite invocation on CPU. This method is
+  // not yet implemented; it is safe to call, as it currently acts as a no-op.
+ void Cancel();
+
+ // Accesses the underlying interpreter for other methods.
+ tflite::Interpreter& operator*() { return *interpreter_; }
+ tflite::Interpreter* operator->() { return interpreter_.get(); }
+ tflite::Interpreter& operator*() const { return *interpreter_; }
+ tflite::Interpreter* operator->() const { return interpreter_.get(); }
+ tflite::Interpreter* get() const { return interpreter_.get(); }
+
+ TfLiteInterpreterWrapper(const TfLiteInterpreterWrapper&) = delete;
+ TfLiteInterpreterWrapper& operator=(const TfLiteInterpreterWrapper&) = delete;
+
+ private:
+ std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_
diff --git a/tensorflow_lite_support/cc/port/gmock.h b/tensorflow_lite_support/cc/port/gmock.h
new file mode 100644
index 00000000..5e4334db
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/gmock.h
@@ -0,0 +1,21 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_
+
+#include "gmock/gmock.h"
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_
diff --git a/tensorflow_lite_support/cc/port/gtest.h b/tensorflow_lite_support/cc/port/gtest.h
new file mode 100644
index 00000000..dbe2e5e6
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/gtest.h
@@ -0,0 +1,21 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_
+
+#include "gtest/gtest.h"
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_
diff --git a/tensorflow_lite_support/cc/port/integral_types.h b/tensorflow_lite_support/cc/port/integral_types.h
new file mode 100644
index 00000000..76d9d503
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/integral_types.h
@@ -0,0 +1,46 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_INTEGRAL_TYPES_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_INTEGRAL_TYPES_H_
+
+// Add namespace here to avoid conflict with other libraries.
+namespace tflite {
+
+typedef signed char schar;
+typedef signed char int8;
+typedef short int16;
+typedef int int32;
+typedef long long int64;
+
+typedef unsigned char uint8;
+typedef unsigned short uint16;
+typedef unsigned int uint32;
+typedef unsigned int char32;
+typedef unsigned long long uint64;
+typedef unsigned long uword_t;
+
+#define GG_LONGLONG(x) x##LL
+#define GG_ULONGLONG(x) x##ULL
+#define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also.
+#define GG_LL_FORMAT_W L"ll"
+
+typedef uint64 Fprint;
+static const Fprint kIllegalFprint = 0;
+static const Fprint kMaxFprint = GG_ULONGLONG(0xFFFFFFFFFFFFFFFF);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_INTEGRAL_TYPES_H_
diff --git a/tensorflow_lite_support/cc/port/status_macros.h b/tensorflow_lite_support/cc/port/status_macros.h
new file mode 100644
index 00000000..3890c772
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/status_macros.h
@@ -0,0 +1,21 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MACROS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MACROS_H_
+
+#include "tensorflow_lite_support/cc/port/default/status_macros.h"
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MACROS_H_
diff --git a/tensorflow_lite_support/cc/port/statusor.h b/tensorflow_lite_support/cc/port/statusor.h
new file mode 100644
index 00000000..f84c7568
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/statusor.h
@@ -0,0 +1,20 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_
+
+#include "tensorflow_lite_support/cc/port/default/statusor.h"
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_
diff --git a/tensorflow_lite_support/cc/port/tflite_wrapper.h b/tensorflow_lite_support/cc/port/tflite_wrapper.h
new file mode 100644
index 00000000..601df9b4
--- /dev/null
+++ b/tensorflow_lite_support/cc/port/tflite_wrapper.h
@@ -0,0 +1,21 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_TFLITE_WRAPPER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_PORT_TFLITE_WRAPPER_H_
+
+#include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h"
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_TFLITE_WRAPPER_H_
diff --git a/tensorflow_lite_support/cc/task/README.md b/tensorflow_lite_support/cc/task/README.md
new file mode 100644
index 00000000..bd756a2e
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/README.md
@@ -0,0 +1,384 @@
+# TFLite Task library - C++
+
+A flexible and ready-to-use library for common machine learning model types,
+such as classification and detection.
+
+## Text Task Libraries
+
+### QuestionAnswerer
+
+`QuestionAnswerer` API is able to load
+[Mobile BERT](https://tfhub.dev/tensorflow/mobilebert/1) or
+[ALBERT](https://tfhub.dev/tensorflow/albert_lite_base/1) TFLite models and
+answer questions based on a given context.
+
+Use the C++ API to answer questions as follows:
+
+```cc
+using tflite::task::text::qa::BertQuestionAnswerer;
+using tflite::task::text::qa::QaAnswer;
+// Create an API handler with the Mobile BERT model.
+auto qa_client = BertQuestionAnswerer::CreateBertQuestionAnswererFromFile("/path/to/mobileBertModel", "/path/to/vocab");
+// Or create an API handler with an ALBERT model.
+// auto qa_client = BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile("/path/to/alBertModel", "/path/to/sentencePieceModel");
+
+
+std::string context =
+ "Nikola Tesla (Serbian Cyrillic: Никола Тесла; 10 "
+ "July 1856 – 7 January 1943) was a Serbian American inventor, electrical "
+ "engineer, mechanical engineer, physicist, and futurist best known for his "
+ "contributions to the design of the modern alternating current (AC) "
+ "electricity supply system.";
+std::string question = "When was Nikola Tesla born?";
+// Run inference with the `context` and a `question` about it, and get the
+// top-k answers ranked by logits.
+const std::vector<QaAnswer> answers = qa_client->Answer(context, question);
+// Access QaAnswer results.
+for (const QaAnswer& item : answers) {
+ std::cout << absl::StrFormat("Text: %s logit=%f start=%d end=%d", item.text,
+ item.pos.logit, item.pos.start, item.pos.end)
+ << std::endl;
+}
+// Output:
+// Text: 10 July 1856 logit=16.8527 start=17 end=19
+// ... (and more)
+//
+// So the top-1 answer is: "10 July 1856".
+```
+
+In the above code, `item.text` is the text content of an answer. We use a span
+with closed interval `[item.pos.start, item.pos.end]` to denote predicted tokens
+in the answer, and `item.pos.logit` is the sum of span logits to represent the
+confidence score.
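+
+Since answers are ranked by logits, the best answer is simply the first
+element (a minimal sketch, assuming `answers` is non-empty):
+
+```cc
+if (!answers.empty()) {
+  const QaAnswer& best_answer = answers[0];
+  std::cout << "Best answer: " << best_answer.text << std::endl;
+}
+```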
+
+### NLClassifier
+
+`NLClassifier` API is able to load any TFLite model for natural language
+classification tasks such as language detection or sentiment detection.
+
+The API expects a TFLite model with the following input/output tensors:
+
+Input tensor #0 (type: `kTfLiteString`):
+
+    - input of the model, accepts a string.
+
+Output tensor #0 (type: `kTfLiteUInt8` / `kTfLiteInt8` / `kTfLiteInt16` /
+`kTfLiteFloat32` / `kTfLiteFloat64`):
+
+    - output scores for each class. If the type is one of the integer types,
+      the API dequantizes the scores to double.
+
+Output tensor #1 (type: `kTfLiteString`, optional):
+
+    - output class name for each class; must have the same length as the
+      scores. If this tensor is not present, the API uses score indices as
+      class names.
+
+By default, the API tries to find the input/output tensors with the default
+configurations in `NLClassifierOptions`, with tensor name prioritized over
+tensor index. The options are configurable for different TFLite models.
+
+Use the C++ API to perform language ID classification as follows:
+
+```cc
+using tflite::task::text::nlclassifier::NLClassifier;
+using tflite::task::core::Category;
+auto classifier = NLClassifier::CreateFromFileAndOptions("/path/to/model");
+// Or create a customized NLClassifierOptions
+// NLClassifierOptions options =
+// {
+// .output_score_tensor_name = myOutputScoreTensorName,
+// .output_label_tensor_name = myOutputLabelTensorName,
+// }
+// auto classifier = NLClassifier::CreateFromFileAndOptions("/path/to/model", options);
+std::string context = "What language is this?";
+std::vector<Category> categories = classifier->Classify(context);
+// Access category results.
+for (const Category& category : categories) {
+  std::cout << absl::StrFormat("Language: %s Probability: %f",
+                               category.class_name, category.score)
+ << std::endl;
+}
+// Output:
+// Language: en Probability: 0.9
+// ... (and more)
+//
+// So the top-1 answer is 'en'.
+```
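+
+If only the most likely class is needed, a minimal sketch (assuming
+`categories` is non-empty; `Category` is the struct from
+`tensorflow_lite_support/cc/task/core/category.h`) is:
+
+```cc
+const Category* best = &categories[0];
+for (const Category& category : categories) {
+  if (category.score > best->score) best = &category;
+}
+std::cout << "Top language: " << best->class_name << std::endl;
+```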
+
+## Vision Task Libraries
+
+### Image Classifier
+
+`ImageClassifier` accepts any TFLite image classification model (with optional,
+but strongly recommended, TFLite Model Metadata) that conforms to the following
+spec:
+
+Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):
+
+ - image input of size `[batch x height x width x channels]`.
+ - batch inference is not supported (`batch` is required to be 1).
+ - only RGB inputs are supported (`channels` is required to be 3).
+ - if type is `kTfLiteFloat32`, `NormalizationOptions` are required to be
+ attached to the metadata for input normalization.
+
+At least one output tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`) with:
+
+ - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
+ `[1 x 1 x 1 x N]`
+ - optional (but recommended) label map(s) as AssociatedFile-s with type
+ TENSOR_AXIS_LABELS, containing one label per line. The first such
+ AssociatedFile (if any) is used to fill the `class_name` field of the
+ results. The `display_name` field is filled from the AssociatedFile (if
+ any) whose locale matches the `display_names_locale` field of the
+ `ImageClassifierOptions` used at creation time ("en" by default, i.e.
+ English). If none of these are available, only the `index` field of the
+ results will be filled.
+
+An example of such a model can be found at:
+https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1
+
+Example usage:
+
+```cc
+// More options are available (e.g. max number of results to return). At the
+// very least, the model must be specified:
+ImageClassifierOptions options;
+options.mutable_model_file_with_metadata()->set_file_name(
+ "/path/to/model.tflite");
+
+// Create an ImageClassifier instance from the options.
+StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
+ ImageClassifier::CreateFromOptions(options);
+// Check if an error occurred.
+if (!image_classifier_or.ok()) {
+ std::cerr << "An error occurred during ImageClassifier creation: "
+ << image_classifier_or.status().message();
+ return;
+}
+std::unique_ptr<ImageClassifier> image_classifier =
+ std::move(image_classifier_or.value());
+
+// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
+std::unique_ptr<FrameBuffer> frame_buffer =
+ CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});
+
+// Run inference:
+StatusOr<ClassificationResult> result_or =
+ image_classifier->Classify(*frame_buffer);
+// Check if an error occurred.
+if (!result_or.ok()) {
+ std::cerr << "An error occurred during classification: "
+ << result_or.status().message();
+ return;
+}
+ClassificationResult result = result_or.value();
+
+// Example value for 'result':
+//
+// classifications {
+// classes { index: 934 score: 0.95 class_name: "cat" }
+// classes { index: 948 score: 0.007 class_name: "dog" }
+// classes { index: 927 score: 0.003 class_name: "fox" }
+// head_index: 0
+// }
+```
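+
+A hedged sketch for reading back the top class of the first head, assuming
+proto accessors that follow the `classifications { classes { ... } }` layout
+shown above:
+
+```cc
+if (result.classifications_size() > 0 &&
+    result.classifications(0).classes_size() > 0) {
+  const auto& top_class = result.classifications(0).classes(0);
+  std::cout << top_class.class_name() << ": " << top_class.score()
+            << std::endl;
+}
+```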
+
+A CLI demo tool is also available [here][1] for easily trying out this API.
+
+### Object Detector
+
+`ObjectDetector` accepts any object detection TFLite model (with mandatory
+TFLite Model Metadata) that conforms to the following spec (e.g. Single Shot
+Detectors):
+
+Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):
+
+ - image input of size `[batch x height x width x channels]`.
+ - batch inference is not supported (`batch` is required to be 1).
+ - only RGB inputs are supported (`channels` is required to be 3).
+ - if type is kTfLiteFloat32, `NormalizationOptions` are required to be
+ attached to the metadata for input normalization.
+
+Output tensors must be the 4 outputs (type: `kTfLiteFloat32`) of a
+[`DetectionPostProcess`][2] op, i.e:
+
+* Locations:
+
+ - of size `[num_results x 4]`, the inner array
+ representing bounding boxes in the form [top, left, right, bottom].
+ - BoundingBoxProperties are required to be attached to the metadata
+ and must specify type=BOUNDARIES and coordinate_type=RATIO.
+
+* Classes:
+
+ - of size `[num_results]`, each value representing the
+ integer index of a class.
+ - optional (but recommended) label map(s) can be attached as
+ AssociatedFile-s with type TENSOR_VALUE_LABELS, containing one label per
+ line. The first such AssociatedFile (if any) is used to fill the
+ `class_name` field of the results. The `display_name` field is filled
+ from the AssociatedFile (if any) whose locale matches the
+ `display_names_locale` field of the `ObjectDetectorOptions` used at
+ creation time ("en" by default, i.e. English). If none of these are
+ available, only the `index` field of the results will be filled.
+
+* Scores:
+
+ - of size `[num_results]`, each value representing the score
+ of the detected object.
+
+* Number of results:
+
+ - integer `num_results` as a tensor of size `[1]`
+
+An example of such a model can be found at:
+https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1
+
+Example usage:
+
+```cc
+// More options are available (e.g. max number of results to return). At the
+// very least, the model must be specified:
+ObjectDetectorOptions options;
+options.mutable_model_file_with_metadata()->set_file_name(
+ "/path/to/model.tflite");
+
+// Create an ObjectDetector instance from the options.
+StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
+ ObjectDetector::CreateFromOptions(options);
+// Check if an error occurred.
+if (!object_detector_or.ok()) {
+ std::cerr << "An error occurred during ObjectDetector creation: "
+ << object_detector_or.status().message();
+ return;
+}
+std::unique_ptr<ObjectDetector> object_detector =
+ std::move(object_detector_or.value());
+
+// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
+std::unique_ptr<FrameBuffer> frame_buffer =
+ CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});
+
+// Run inference:
+StatusOr<DetectionResult> result_or = object_detector->Detect(*frame_buffer);
+// Check if an error occurred.
+if (!result_or.ok()) {
+ std::cerr << "An error occurred during detection: "
+ << result_or.status().message();
+ return;
+}
+DetectionResult result = result_or.value();
+
+// Example value for 'result':
+//
+// detections {
+// bounding_box {
+// origin_x: 54
+// origin_y: 398
+// width: 393
+// height: 196
+// }
+// classes { index: 16 score: 0.65 class_name: "cat" }
+// }
+// detections {
+// bounding_box {
+// origin_x: 602
+// origin_y: 157
+// width: 394
+// height: 447
+// }
+// classes { index: 17 score: 0.45 class_name: "dog" }
+// }
+```
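+
+A hedged sketch for iterating over the detections, assuming proto accessors
+that follow the `detections { bounding_box { ... } classes { ... } }` layout
+shown above, and at least one class per detection:
+
+```cc
+for (const auto& detection : result.detections()) {
+  const auto& box = detection.bounding_box();
+  std::cout << detection.classes(0).class_name() << " (score "
+            << detection.classes(0).score() << ") at [" << box.origin_x()
+            << ", " << box.origin_y() << ", " << box.width() << "x"
+            << box.height() << "]" << std::endl;
+}
+```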
+
+A CLI demo tool is available [here][3] for easily trying out this API.
+
+### Image Segmenter
+
+`ImageSegmenter` accepts any TFLite model (with optional, but strongly
+recommended, TFLite Model Metadata) that conforms to the following spec:
+
+Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):
+
+ - image input of size `[batch x height x width x channels]`.
+ - batch inference is not supported (`batch` is required to be 1).
+ - only RGB inputs are supported (`channels` is required to be 3).
+ - if type is kTfLiteFloat32, `NormalizationOptions` are required to be
+ attached to the metadata for input normalization.
+
+Output tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):
+
+ - tensor of size `[batch x mask_height x mask_width x num_classes]`, where
+ `batch` is required to be 1, `mask_width` and `mask_height` are the
+ dimensions of the segmentation masks produced by the model, and
+ `num_classes` is the number of classes supported by the model.
+ - optional (but recommended) label map(s) can be attached as
+ AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per
+ line. The first such AssociatedFile (if any) is used to fill the
+ `class_name` field of the results. The `display_name` field is filled
+ from the AssociatedFile (if any) whose locale matches the
+ `display_names_locale` field of the `ImageSegmenterOptions` used at
+ creation time ("en" by default, i.e. English). If none of these are
+ available, only the `index` field of the results will be filled.
+
+An example of such a model can be found at:
+https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1
+
+Example usage:
+
+```cc
+// More options are available to select between returning a single category
+// mask or multiple confidence masks during post-processing.
+ImageSegmenterOptions options;
+options.mutable_model_file_with_metadata()->set_file_name(
+ "/path/to/model.tflite");
+
+// Create an ImageSegmenter instance from the options.
+StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
+ ImageSegmenter::CreateFromOptions(options);
+// Check if an error occurred.
+if (!image_segmenter_or.ok()) {
+ std::cerr << "An error occurred during ImageSegmenter creation: "
+ << image_segmenter_or.status().message();
+ return;
+}
+std::unique_ptr<ImageSegmenter> image_segmenter =
+ std::move(image_segmenter_or.value());
+
+// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
+std::unique_ptr<FrameBuffer> frame_buffer =
+ CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});
+
+// Run inference:
+StatusOr<SegmentationResult> result_or =
+    image_segmenter->Segment(*frame_buffer);
+// Check if an error occurred.
+if (!result_or.ok()) {
+ std::cerr << "An error occurred during segmentation: "
+ << result_or.status().message();
+ return;
+}
+SegmentationResult result = result_or.value();
+
+// Example value for 'result':
+//
+// segmentation {
+// width: 257
+// height: 257
+// category_mask: "\x00\x01..."
+// colored_labels { r: 0 g: 0 b: 0 class_name: "background" }
+// colored_labels { r: 128 g: 0 b: 0 class_name: "aeroplane" }
+// ...
+// colored_labels { r: 128 g: 192 b: 0 class_name: "train" }
+// colored_labels { r: 0 g: 64 b: 128 class_name: "tv" }
+// }
+//
+// Where 'category_mask' is a byte buffer of size 'width' x 'height', with the
+// value of each pixel representing the class this pixel belongs to (e.g. '\x00'
+// means "background", '\x01' means "aeroplane", etc).
+// 'colored_labels' provides the label for each possible value, as well as
+// suggested RGB components to optionally transform the result into a more
+// human-friendly colored image.
+//
+```
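+
+A hedged sketch for looking up the class of a given pixel in the category
+mask, assuming one byte per pixel in row-major order as described above and a
+`segmentation(0)` proto accessor:
+
+```cc
+const auto& segmentation = result.segmentation(0);
+int x = 10, y = 20;  // Pixel of interest, in mask coordinates.
+uint8_t class_index = static_cast<uint8_t>(
+    segmentation.category_mask()[y * segmentation.width() + x]);
+```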
+
+A CLI demo tool is available [here][4] for easily trying out this API.
+
+[1]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
+[2]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
+[3]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
+[4]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
diff --git a/tensorflow_lite_support/cc/task/core/BUILD b/tensorflow_lite_support/cc/task/core/BUILD
new file mode 100644
index 00000000..1995dfe3
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/BUILD
@@ -0,0 +1,156 @@
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "tflite_engine",
+ srcs = ["tflite_engine.cc"],
+ hdrs = ["tflite_engine.h"],
+ deps = [
+ ":external_file_handler",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
+ # The dependency on builtin_ops here is only for the default
+ # value of the OpResolver parameter:
+ # std::unique_ptr<tflite::IterableOpResolver> resolver =
+ # absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()
+ # When linking statically, if the client of this library doesn't use
+ # the default argument, this dependency does not cause all the builtin ops
+ # to get included in the executable.
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ "@org_tensorflow//tensorflow/lite/tools:verifier",
+ ] + select({
+ "//tensorflow_lite_support/cc:tflite_use_c_api": [
+ "@org_tensorflow//tensorflow/lite/core/api:verifier",
+ "@org_tensorflow//tensorflow/lite/c:c_api",
+ "@org_tensorflow//tensorflow/lite/c:c_api_experimental",
+ "@org_tensorflow//tensorflow/lite:kernel_api",
+ "@org_tensorflow//tensorflow/lite:stderr_reporter",
+ ],
+ "//conditions:default": [
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite:kernel_api",
+ ],
+ }) + [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:tflite_wrapper",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ ],
+)
+
+# This is a duplicate of the above 'tflite_engine' target that is used for
+# testing with TFLITE_USE_C_API defined. It should be the same as the target
+# above, except that it adds
+# testonly = 1,
+# defines = ["TFLITE_USE_C_API"],
+# and that it resolves the conditional deps from the 'select' as if
+# "//tensorflow_lite_support/cc:tflite_use_c_api" was enabled.
+# This allows testing the TFLITE_USE_C_API case even when
+# '--copt=-DTFLITE_USE_C_API' wasn't passed on the build command line.
+cc_library(
+ name = "tflite_engine_with_c_api_for_test",
+ testonly = 1,
+ srcs = ["tflite_engine.cc"],
+ hdrs = ["tflite_engine.h"],
+ defines = ["TFLITE_USE_C_API"],
+ deps = [
+ ":external_file_handler",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/core/api",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ "@org_tensorflow//tensorflow/lite/tools:verifier",
+ ] + [
+ "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
+ "@org_tensorflow//tensorflow/lite/core/api:verifier",
+ "@org_tensorflow//tensorflow/lite/c:c_api",
+ "@org_tensorflow//tensorflow/lite/c:c_api_experimental",
+ "@org_tensorflow//tensorflow/lite:kernel_api",
+ "@org_tensorflow//tensorflow/lite:stderr_reporter",
+ ] + [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:tflite_wrapper_with_c_api_for_test",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ ],
+)
+
+cc_library(
+ name = "base_task_api",
+ hdrs = ["base_task_api.h"],
+ deps = [
+ ":tflite_engine",
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/port:tflite_wrapper",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ ],
+)
+
+cc_library(
+ name = "task_api_factory",
+ hdrs = ["task_api_factory.h"],
+ deps = [
+ ":base_task_api",
+ ":tflite_engine",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ "@com_google_absl//absl/status",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
+ "@org_tensorflow//tensorflow/lite/kernels:op_macros",
+ ],
+)
+
+cc_library(
+ name = "task_utils",
+ srcs = ["task_utils.cc"],
+ hdrs = ["task_utils.h"],
+ deps = [
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:string_util",
+ "@org_tensorflow//tensorflow/lite:type_to_tflitetype",
+ "@org_tensorflow//tensorflow/lite/kernels:op_macros",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+ ],
+)
+
+cc_library(
+ name = "category",
+ hdrs = ["category.h"],
+)
+
+cc_library(
+ name = "external_file_handler",
+ srcs = ["external_file_handler.cc"],
+ hdrs = ["external_file_handler.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
diff --git a/tensorflow_lite_support/cc/task/core/base_task_api.h b/tensorflow_lite_support/cc/task/core/base_task_api.h
new file mode 100644
index 00000000..a27f785b
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/base_task_api.h
@@ -0,0 +1,144 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_
+
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+
+namespace tflite {
+namespace task {
+namespace core {
+
+class BaseUntypedTaskApi {
+ public:
+ explicit BaseUntypedTaskApi(std::unique_ptr<TfLiteEngine> engine)
+ : engine_{std::move(engine)} {}
+
+ virtual ~BaseUntypedTaskApi() = default;
+
+ TfLiteEngine* GetTfLiteEngine() { return engine_.get(); }
+ const TfLiteEngine* GetTfLiteEngine() const { return engine_.get(); }
+
+ const metadata::ModelMetadataExtractor* GetMetadataExtractor() const {
+ return engine_->metadata_extractor();
+ }
+
+ protected:
+ std::unique_ptr<TfLiteEngine> engine_;
+};
+
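+// A hedged sketch of a minimal subclass (hypothetical `FloatIdentityApi`, not
+// part of this header), illustrating the Preprocess/Postprocess contract for
+// a model with a single float input and a single float output:
+//
+//   class FloatIdentityApi : public BaseTaskApi<float, float> {
+//    public:
+//     using BaseTaskApi<float, float>::BaseTaskApi;
+//     tflite::support::StatusOr<float> Run(float in) { return Infer(in); }
+//
+//    protected:
+//     absl::Status Preprocess(const std::vector<TfLiteTensor*>& inputs,
+//                             float in) override {
+//       // Assumes a single float input tensor.
+//       *tflite::GetTensorData<float>(inputs[0]) = in;
+//       return absl::OkStatus();
+//     }
+//     tflite::support::StatusOr<float> Postprocess(
+//         const std::vector<const TfLiteTensor*>& outputs, float) override {
+//       return *tflite::GetTensorData<float>(outputs[0]);
+//     }
+//   };
+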
+template <class OutputType, class... InputTypes>
+class BaseTaskApi : public BaseUntypedTaskApi {
+ public:
+ explicit BaseTaskApi(std::unique_ptr<TfLiteEngine> engine)
+ : BaseUntypedTaskApi(std::move(engine)) {}
+ // BaseTaskApi is neither copyable nor movable.
+ BaseTaskApi(const BaseTaskApi&) = delete;
+ BaseTaskApi& operator=(const BaseTaskApi&) = delete;
+
+  // Cancels the currently running TFLite invocation on CPU.
+  //
+  // Usually called on a different thread than the one inference is running on.
+  // Calling Cancel() will cause the underlying TFLite interpreter to return an
+  // error, which will turn into a `CANCELLED` status and empty results.
+  // Calling Cancel() at any other time has no effect on the current or
+  // following invocations. It is perfectly fine to run inference again on the
+  // same instance after a cancelled invocation. If the TFLite inference is
+  // partially delegated, a warning message is logged and only the part of the
+  // invocation running on CPU is cancelled. Other parts of the invocation that
+  // depend on the output of the CPU part will not be executed.
+ void Cancel() { engine_->Cancel(); }
+
+ protected:
+ // Subclasses need to populate input_tensors from api_inputs.
+ virtual absl::Status Preprocess(
+ const std::vector<TfLiteTensor*>& input_tensors,
+ InputTypes... api_inputs) = 0;
+
+ // Subclasses need to construct OutputType object from output_tensors.
+ // Original inputs are also provided as they may be needed.
+ virtual tflite::support::StatusOr<OutputType> Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ InputTypes... api_inputs) = 0;
+
+ // Returns (the addresses of) the model's inputs.
+ std::vector<TfLiteTensor*> GetInputTensors() { return engine_->GetInputs(); }
+
+ // Returns (the addresses of) the model's outputs.
+ std::vector<const TfLiteTensor*> GetOutputTensors() {
+ return engine_->GetOutputs();
+ }
+
+ // Performs inference using tflite::support::TfLiteInterpreterWrapper
+ // InvokeWithoutFallback().
+ tflite::support::StatusOr<OutputType> Infer(InputTypes... args) {
+ tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper =
+ engine_->interpreter_wrapper();
+ // Note: AllocateTensors() is already performed by the interpreter wrapper
+ // at InitInterpreter time (see TfLiteEngine).
+ RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...));
+ absl::Status status = interpreter_wrapper->InvokeWithoutFallback();
+ if (!status.ok()) {
+ return status.GetPayload(tflite::support::kTfLiteSupportPayload)
+ .has_value()
+ ? status
+ : tflite::support::CreateStatusWithPayload(status.code(),
+ status.message());
+ }
+ return Postprocess(GetOutputTensors(), args...);
+ }
+
+ // Performs inference using tflite::support::TfLiteInterpreterWrapper
+ // InvokeWithFallback() to benefit from automatic fallback from delegation to
+ // CPU where applicable.
+ tflite::support::StatusOr<OutputType> InferWithFallback(InputTypes... args) {
+ tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper =
+ engine_->interpreter_wrapper();
+ // Note: AllocateTensors() is already performed by the interpreter wrapper
+ // at InitInterpreter time (see TfLiteEngine).
+ RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...));
+ auto set_inputs_nop =
+ [](tflite::task::core::TfLiteEngine::Interpreter* interpreter)
+ -> absl::Status {
+ // NOP since inputs are populated at Preprocess() time.
+ return absl::OkStatus();
+ };
+ absl::Status status =
+ interpreter_wrapper->InvokeWithFallback(set_inputs_nop);
+ if (!status.ok()) {
+ return status.GetPayload(tflite::support::kTfLiteSupportPayload)
+ .has_value()
+ ? status
+ : tflite::support::CreateStatusWithPayload(status.code(),
+ status.message());
+ }
+ return Postprocess(GetOutputTensors(), args...);
+ }
+};
+
+} // namespace core
+} // namespace task
+} // namespace tflite
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_
diff --git a/tensorflow_lite_support/cc/task/core/category.h b/tensorflow_lite_support/cc/task/core/category.h
new file mode 100644
index 00000000..a99f994c
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/category.h
@@ -0,0 +1,44 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_
+#include <string>
+
+namespace tflite {
+namespace task {
+namespace core {
+
+// Result for classification APIs.
+struct Category {
+ std::string class_name;
+ double score;
+ Category(const std::string& class_name, double score)
+ : class_name(class_name), score(score) {}
+
+ friend bool operator==(const Category& lhs, const Category& rhs) {
+ return lhs.score == rhs.score && lhs.class_name == rhs.class_name;
+ }
+
+ friend bool operator!=(const Category& lhs, const Category& rhs) {
+ return !(lhs == rhs);
+ }
+};
+
+} // namespace core
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_
diff --git a/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/tensorflow_lite_support/cc/task/core/external_file_handler.cc
new file mode 100644
index 00000000..e2150c13
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/external_file_handler.cc
@@ -0,0 +1,194 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <stddef.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+
+namespace tflite {
+namespace task {
+namespace core {
+namespace {
+
+using ::absl::StatusCode;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+
+// Gets the offset aligned to page size for correctly mapping the given file
+// into memory by file descriptor: according to mmap(2), the offset used in
+// mmap() must be a multiple of sysconf(_SC_PAGE_SIZE).
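+// For example, with a 4096-byte page size an offset of 5000 aligns down to
+// 4096, and an offset of 8192 is returned unchanged.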
+int64 GetPageSizeAlignedOffset(int64 offset) {
+ int64 aligned_offset = offset;
+ int64 page_size = sysconf(_SC_PAGE_SIZE);
+ if (offset % page_size != 0) {
+ aligned_offset = offset / page_size * page_size;
+ }
+ return aligned_offset;
+}
+
+} // namespace
+
+/* static */
+StatusOr<std::unique_ptr<ExternalFileHandler>>
+ExternalFileHandler::CreateFromExternalFile(const ExternalFile* external_file) {
+ // Use absl::WrapUnique() to call private constructor:
+ // https://abseil.io/tips/126.
+ std::unique_ptr<ExternalFileHandler> handler =
+ absl::WrapUnique(new ExternalFileHandler(external_file));
+
+ RETURN_IF_ERROR(handler->MapExternalFile());
+
+ return handler;
+}
+
+absl::Status ExternalFileHandler::MapExternalFile() {
+ if (!external_file_.file_content().empty()) {
+ return absl::OkStatus();
+ }
+ if (external_file_.file_name().empty() &&
+ !external_file_.has_file_descriptor_meta()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "ExternalFile must specify at least one of 'file_content', file_name' "
+ "or 'file_descriptor_meta'.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ // Obtain file descriptor, offset and size.
+ int fd = -1;
+ if (!external_file_.file_name().empty()) {
+ owned_fd_ = open(external_file_.file_name().c_str(), O_RDONLY);
+ if (owned_fd_ < 0) {
+ const std::string error_message = absl::StrFormat(
+ "Unable to open file at %s", external_file_.file_name());
+ switch (errno) {
+ case ENOENT:
+ return CreateStatusWithPayload(
+ StatusCode::kNotFound, error_message,
+ TfLiteSupportStatus::kFileNotFoundError);
+ case EACCES:
+ case EPERM:
+ return CreateStatusWithPayload(
+ StatusCode::kPermissionDenied, error_message,
+ TfLiteSupportStatus::kFilePermissionDeniedError);
+ case EINTR:
+ return CreateStatusWithPayload(StatusCode::kUnavailable,
+ error_message,
+ TfLiteSupportStatus::kFileReadError);
+ case EBADF:
+ return CreateStatusWithPayload(StatusCode::kFailedPrecondition,
+ error_message,
+ TfLiteSupportStatus::kFileReadError);
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown,
+ absl::StrFormat("%s, errno=%d", error_message, errno),
+ TfLiteSupportStatus::kFileReadError);
+ }
+ }
+ fd = owned_fd_;
+ } else {
+ fd = external_file_.file_descriptor_meta().fd();
+ if (fd < 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Provided file descriptor is invalid: %d < 0", fd),
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ buffer_offset_ = external_file_.file_descriptor_meta().offset();
+ buffer_size_ = external_file_.file_descriptor_meta().length();
+ }
+  // Get the actual file size: with SEEK_END, lseek(2) returns file size plus
+  // the given offset, hence offset 0. off_t (signed) keeps -1 detectable.
+  off_t file_size = lseek(fd, /*offset=*/0, SEEK_END);
+  if (file_size <= 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown,
+ absl::StrFormat("Unable to get file size, errno=%d", errno),
+ TfLiteSupportStatus::kFileReadError);
+ }
+ // Deduce buffer size if not explicitly provided through file descriptor.
+ if (buffer_size_ <= 0) {
+ buffer_size_ = file_size - buffer_offset_;
+ }
+ // Check for out of range issues.
+ if (file_size <= buffer_offset_) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Provided file offset (%d) exceeds or matches actual "
+ "file length (%d)",
+ buffer_offset_, file_size),
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (file_size < buffer_size_ + buffer_offset_) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Provided file length + offset (%d) exceeds actual "
+ "file length (%d)",
+ buffer_size_ + buffer_offset_, file_size),
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ // If buffer_offset_ is not multiple of sysconf(_SC_PAGE_SIZE), align with
+ // extra leading bytes and adjust buffer_size_ to account for the extra
+ // leading bytes.
+ buffer_aligned_offset_ = GetPageSizeAlignedOffset(buffer_offset_);
+ buffer_aligned_size_ = buffer_size_ + buffer_offset_ - buffer_aligned_offset_;
+ // Map into memory.
+ buffer_ = mmap(/*addr=*/nullptr, buffer_aligned_size_, PROT_READ, MAP_SHARED,
+ fd, buffer_aligned_offset_);
+ if (buffer_ == MAP_FAILED) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown,
+ absl::StrFormat("Unable to map file to memory buffer, errno=%d", errno),
+ TfLiteSupportStatus::kFileMmapError);
+ }
+ return absl::OkStatus();
+}
+
+absl::string_view ExternalFileHandler::GetFileContent() {
+ if (!external_file_.file_content().empty()) {
+ return external_file_.file_content();
+ } else {
+ return absl::string_view(static_cast<const char*>(buffer_) +
+ buffer_offset_ - buffer_aligned_offset_,
+ buffer_size_);
+ }
+}
+
+ExternalFileHandler::~ExternalFileHandler() {
+ if (buffer_ != MAP_FAILED) {
+ munmap(buffer_, buffer_aligned_size_);
+ }
+ if (owned_fd_ >= 0) {
+ close(owned_fd_);
+ }
+}
+
+} // namespace core
+} // namespace task
+} // namespace tflite
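
To illustrate the alignment logic above, here is a standalone sketch of the same arithmetic, assuming a 4096-byte page size for the example values (the real code queries sysconf(_SC_PAGE_SIZE)):

    #include <cstdint>
    #include <iostream>

    // Rounds 'offset' down to the nearest multiple of 'page_size'.
    int64_t GetPageSizeAlignedOffset(int64_t offset, int64_t page_size) {
      return offset / page_size * page_size;
    }

    int main() {
      const int64_t page_size = 4096;  // Assumed for illustration.
      const int64_t offset = 10000;    // Requested offset into the file.
      const int64_t aligned = GetPageSizeAlignedOffset(offset, page_size);
      // mmap starts at 8192; GetFileContent() later skips the
      // offset - aligned = 1808 extra leading bytes.
      std::cout << "aligned=" << aligned
                << " leading=" << (offset - aligned) << "\n";
      return 0;
    }
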
diff --git a/tensorflow_lite_support/cc/task/core/external_file_handler.h b/tensorflow_lite_support/cc/task/core/external_file_handler.h
new file mode 100644
index 00000000..236d9034
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/external_file_handler.h
@@ -0,0 +1,94 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_EXTERNAL_FILE_HANDLER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_EXTERNAL_FILE_HANDLER_H_
+
+#include <memory>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+
+namespace tflite {
+namespace task {
+namespace core {
+
+// Handler providing easy access to the contents of a file specified by an
+// ExternalFile proto [1]. Takes care (if needed, depending on the provided
+// proto fields) of opening and/or mapping the file in memory at creation time,
+// as well as closing and/or unmapping at destruction time.
+//
+// [1]: tensorflow_lite_support/cc/task/core/proto/external_file.proto
+class ExternalFileHandler {
+ public:
+ // Creates an ExternalFileHandler from the input ExternalFile proto and
+ // returns a pointer to the new object. Ownership is transferred to the
+ // caller. Returns an error if the creation failed, which may happen if the
+ // provided ExternalFile can't be opened or mapped into memory.
+ //
+ // Warning: Does not take ownership of `external_file`, which must refer to a
+ // valid proto that outlives this object.
+ static tflite::support::StatusOr<std::unique_ptr<ExternalFileHandler>>
+ CreateFromExternalFile(const ExternalFile* external_file);
+
+ ~ExternalFileHandler();
+
+ // Returns the content of the ExternalFile as a string_view guaranteed to be
+ // valid as long as the ExternalFileHandler is alive.
+ absl::string_view GetFileContent();
+
+ private:
+ // Private constructor, called from CreateFromExternalFile().
+ explicit ExternalFileHandler(const ExternalFile* external_file)
+ : external_file_(*external_file) {}
+
+ // Opens (if provided by path) and maps (if provided by path or file
+ // descriptor) the external file in memory. Does nothing otherwise, as file
+ // contents are already loaded in memory.
+ absl::Status MapExternalFile();
+
+ // Reference to the input ExternalFile.
+ const ExternalFile& external_file_;
+
+ // The file descriptor of the ExternalFile if provided by path, as it is
+ // opened and owned by this class. Set to -1 otherwise.
+ int owned_fd_{-1};
+
+ // Points to the memory buffer mapped from the file descriptor of the
+ // ExternalFile, if provided by path or file descriptor.
+ void* buffer_{};
+
+ // The mapped memory buffer offset, if any.
+ int64 buffer_offset_{};
+ // The size in bytes of the mapped memory buffer, if any.
+ int64 buffer_size_{};
+
+ // As mmap(2) requires the offset to be a multiple of sysconf(_SC_PAGE_SIZE):
+
+ // The aligned mapped memory buffer offset, if any.
+ int64 buffer_aligned_offset_{};
+  // The aligned mapped memory buffer size in bytes, taking into account the
+  // offset shift introduced by buffer_aligned_offset_, if any.
+ int64 buffer_aligned_size_{};
+};
+
+} // namespace core
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_EXTERNAL_FILE_HANDLER_H_
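
A minimal usage sketch for this handler, assuming a model file exists at the hypothetical path below:

    #include <iostream>

    #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
    #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"

    int main() {
      tflite::task::core::ExternalFile file;
      file.set_file_name("/path/to/model.tflite");  // Hypothetical path.
      // The ExternalFile proto must outlive the handler.
      auto handler_or =
          tflite::task::core::ExternalFileHandler::CreateFromExternalFile(&file);
      if (!handler_or.ok()) {
        std::cerr << handler_or.status().message() << "\n";
        return 1;
      }
      absl::string_view content = (*handler_or)->GetFileContent();
      std::cout << "mapped " << content.size() << " bytes\n";
      return 0;
    }
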
diff --git a/tensorflow_lite_support/cc/task/core/proto/BUILD b/tensorflow_lite_support/cc/task/core/proto/BUILD
new file mode 100644
index 00000000..7418e5b2
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/proto/BUILD
@@ -0,0 +1,27 @@
+load("//tensorflow_lite_support/cc/port:build_defs.bzl", "support_cc_proto_library")
+
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+proto_library(
+ name = "external_file_proto",
+ srcs = ["external_file.proto"],
+)
+
+support_cc_proto_library(
+ name = "external_file_cc_proto",
+ srcs = ["external_file.proto"],
+ deps = [
+ ":external_file_proto",
+ ],
+)
+
+cc_library(
+ name = "external_file_proto_inc",
+ hdrs = ["external_file_proto_inc.h"],
+ deps = [":external_file_cc_proto"],
+)
diff --git a/tensorflow_lite_support/cc/task/core/proto/external_file.proto b/tensorflow_lite_support/cc/task/core/proto/external_file.proto
new file mode 100644
index 00000000..c0a42124
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/proto/external_file.proto
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package tflite.task.core;
+
+
+// Represents external files used by the Task APIs (e.g. TF Lite FlatBuffer or
+// plain-text labels file). The files can be specified in one of the following
+// three ways:
+//
+// (1) file contents loaded in `file_content`.
+// (2) file path in `file_name`.
+// (3) file descriptor through `file_descriptor_meta` as returned by open(2).
+//
+// If more than one of these fields is provided, they are used in this order
+// of precedence.
+// Next id: 5
+message ExternalFile {
+  // The path to the file to open and mmap in memory.
+ optional string file_name = 1;
+
+ // The file contents as a byte array.
+ optional bytes file_content = 2;
+
+ // The file descriptor to a file opened with open(2), with optional additional
+ // offset and length information.
+ optional FileDescriptorMeta file_descriptor_meta = 4;
+
+ // Deprecated field numbers.
+ reserved 3;
+}
+
+// A proto defining file descriptor metadata for mapping file into memory using
+// mmap(2).
+message FileDescriptorMeta {
+ // File descriptor as returned by open(2).
+ optional int32 fd = 1;
+
+ // Optional length of the mapped memory. If not specified, the actual file
+ // size is used at runtime.
+ //
+ // This is an advanced option, e.g. this can be used on Android to specify the
+ // length of a given asset obtained from AssetFileDescriptor#getLength().
+ optional int64 length = 2;
+
+ // Optional starting offset in the file referred to by the file descriptor
+ // `fd`.
+ //
+ // This is an advanced option, e.g. this can be used on Android to specify the
+ // offset of a given asset obtained from AssetFileDescriptor#getStartOffset().
+ optional int64 offset = 3;
+}
+
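
A sketch of populating the file descriptor variant of this proto, e.g. for an Android asset (the fd, offset and length values are hypothetical):

    #include <cstdint>

    #include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"

    tflite::task::core::ExternalFile DescribeAsset(int fd, int64_t offset,
                                                   int64_t length) {
      tflite::task::core::ExternalFile file;
      auto* meta = file.mutable_file_descriptor_meta();
      meta->set_fd(fd);          // As returned by open(2).
      meta->set_offset(offset);  // E.g. AssetFileDescriptor#getStartOffset().
      meta->set_length(length);  // E.g. AssetFileDescriptor#getLength().
      return file;
    }
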
diff --git a/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h b/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h
new file mode 100644
index 00000000..017aa651
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h
@@ -0,0 +1,20 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_EXTERNAL_FILE_PROTO_INC_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_EXTERNAL_FILE_PROTO_INC_H_
+
+#include "tensorflow_lite_support/cc/task/core/proto/external_file.pb.h"
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_EXTERNAL_FILE_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt b/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt
new file mode 100644
index 00000000..dafb0fde
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt
@@ -0,0 +1,2 @@
+allow_all: true
+optimize_mode: LITE_RUNTIME
diff --git a/tensorflow_lite_support/cc/task/core/task_api_factory.h b/tensorflow_lite_support/cc/task/core/task_api_factory.h
new file mode 100644
index 00000000..06c3a012
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/task_api_factory.h
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_API_FACTORY_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_API_FACTORY_H_
+
+#include <memory>
+
+#include "absl/status/status.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+
+namespace tflite {
+namespace task {
+namespace core {
+template <typename T>
+using EnableIfBaseUntypedTaskApiSubclass = typename std::enable_if<
+ std::is_base_of<BaseUntypedTaskApi, T>::value>::type*;
+
+// Templated factory for all subclasses of BaseTaskApi.
+class TaskAPIFactory {
+ public:
+ TaskAPIFactory() = delete;
+
+ template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
+ static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBuffer(
+ const char* buffer_data, size_t buffer_size,
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ int num_threads = 1) {
+ auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver));
+ RETURN_IF_ERROR(engine->BuildModelFromFlatBuffer(buffer_data, buffer_size));
+ return CreateFromTfLiteEngine<T>(std::move(engine), num_threads);
+ }
+
+ template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
+ static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromFile(
+      const std::string& file_name,
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ int num_threads = 1) {
+ auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver));
+ RETURN_IF_ERROR(engine->BuildModelFromFile(file_name));
+ return CreateFromTfLiteEngine<T>(std::move(engine), num_threads);
+ }
+
+ template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
+ static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromFileDescriptor(
+ int file_descriptor,
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ int num_threads = 1) {
+ auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver));
+ RETURN_IF_ERROR(engine->BuildModelFromFileDescriptor(file_descriptor));
+ return CreateFromTfLiteEngine<T>(std::move(engine), num_threads);
+ }
+
+ template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
+ static tflite::support::StatusOr<std::unique_ptr<T>>
+ CreateFromExternalFileProto(
+ const ExternalFile* external_file,
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ int num_threads = 1) {
+ auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver));
+ RETURN_IF_ERROR(engine->BuildModelFromExternalFileProto(external_file));
+ return CreateFromTfLiteEngine<T>(std::move(engine), num_threads);
+ }
+
+ private:
+ template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr>
+ static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine(
+ std::unique_ptr<TfLiteEngine> engine, int num_threads) {
+ RETURN_IF_ERROR(engine->InitInterpreter(num_threads));
+ return absl::make_unique<T>(std::move(engine));
+ }
+};
+
+} // namespace core
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_API_FACTORY_H_
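
As a usage sketch, the BertNLClassifier added later in this change goes through this factory; calling it directly would look roughly as follows (hypothetical model path, default resolver and thread count; BertNLClassifier::CreateFromFile additionally initializes from metadata):

    #include <memory>

    #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
    #include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"

    tflite::support::StatusOr<
        std::unique_ptr<tflite::task::text::nlclassifier::BertNLClassifier>>
    CreateClassifier() {
      // Uses the default BuiltinOpResolver and a single inference thread.
      return tflite::task::core::TaskAPIFactory::CreateFromFile<
          tflite::task::text::nlclassifier::BertNLClassifier>(
          "/path/to/model.tflite");
    }
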
diff --git a/tensorflow_lite_support/cc/task/core/task_utils.cc b/tensorflow_lite_support/cc/task/core/task_utils.cc
new file mode 100644
index 00000000..de733ae9
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/task_utils.cc
@@ -0,0 +1,66 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+
+#include <fstream>
+
+#include "absl/strings/str_cat.h"
+
+namespace tflite {
+namespace task {
+namespace core {
+
+double Dequantize(const TfLiteTensor& tensor, int index) {
+ int32_t quantized_value = 0;
+ switch (tensor.type) {
+ case kTfLiteUInt8:
+ quantized_value = GetTensorData<uint8_t>(&tensor)[index];
+ break;
+ case kTfLiteInt8:
+ quantized_value = GetTensorData<int8_t>(&tensor)[index];
+ break;
+ case kTfLiteInt16:
+ quantized_value = GetTensorData<int16_t>(&tensor)[index];
+ break;
+ default:
+ TF_LITE_FATAL(
+ absl::StrCat(
+ "Invalid tensor type for dequantization ", tensor.name,
+ ". Requested kTfLiteUInt8, kTfLiteInt8 or kTfLiteInt16, got ",
+ TfLiteTypeGetName(tensor.type), ".")
+ .c_str());
+ }
+ return tensor.params.scale * (quantized_value - tensor.params.zero_point);
+}
+
+std::string GetStringAtIndex(const TfLiteTensor* labels, int index) {
+ const auto& strref = tflite::GetString(labels, index);
+ return std::string(strref.str, strref.len);
+}
+
+std::string LoadBinaryContent(const char* filename) {
+ std::ifstream input_file(filename, std::ios::binary | std::ios::ate);
+  if (!input_file.is_open()) return std::string();  // tellg() would be -1.
+  size_t buffer_size = input_file.tellg();
+ std::string buffer(buffer_size, '\0');
+ input_file.seekg(0, std::ios::beg);
+ input_file.read(const_cast<char*>(buffer.c_str()), buffer_size);
+ return buffer;
+}
+
+} // namespace core
+} // namespace task
+} // namespace tflite
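
Dequantize() above applies the standard affine dequantization formula, real_value = scale * (quantized_value - zero_point). A standalone sketch with made-up quantization parameters:

    #include <cstdint>
    #include <iostream>

    int main() {
      const float scale = 0.5f;        // Would come from tensor.params.scale.
      const int32_t zero_point = 128;  // From tensor.params.zero_point.
      const uint8_t quantized = 200;   // A kTfLiteUInt8 tensor element.
      const double real =
          scale * (static_cast<int32_t>(quantized) - zero_point);
      std::cout << real << "\n";  // 0.5 * (200 - 128) = 36
      return 0;
    }
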
diff --git a/tensorflow_lite_support/cc/task/core/task_utils.h b/tensorflow_lite_support/cc/task/core/task_utils.h
new file mode 100644
index 00000000..c1c3fc31
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/task_utils.h
@@ -0,0 +1,182 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_
+
+#include <algorithm>
+#include <cstring>
+#include <numeric>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/string_util.h"
+#include "tensorflow/lite/type_to_tflitetype.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace core {
+
+// Checks if data type of tensor is T and returns the pointer casted to T if
+// applicable, returns nullptr if tensor type is not T.
+// See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType.
+template <typename T>
+T* TypedTensor(const TfLiteTensor* tensor_ptr) {
+ if (tensor_ptr->type == typeToTfLiteType<T>()) {
+ return reinterpret_cast<T*>(tensor_ptr->data.raw);
+ }
+ return nullptr;
+}
+
+// Returns a typed pointer to the tensor's data; fails if its type is not T.
+template <typename T>
+T* AssertAndReturnTypedTensor(const TfLiteTensor* tensor) {
+ if (T* v = TypedTensor<T>(tensor)) return v;
+ // TODO(b/150903834): throw exceptions instead
+ TF_LITE_ASSERT(tensor->data.raw);
+ TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name,
+ ". Requested ",
+ TfLiteTypeGetName(typeToTfLiteType<T>()), ", got ",
+ TfLiteTypeGetName(tensor->type), ".")
+ .c_str());
+}
+
+// Populates a tensor with an array of data; fails if the data type doesn't
+// match the tensor type or if the number of elements doesn't match.
+template <typename T>
+inline void PopulateTensor(const T* data, int num_elements,
+ TfLiteTensor* tensor) {
+ T* v = AssertAndReturnTypedTensor<T>(tensor);
+ size_t bytes = num_elements * sizeof(T);
+ // TODO(b/150903834): throw exceptions instead
+ TF_LITE_ASSERT(tensor->bytes == bytes);
+ memcpy(v, data, bytes);
+}
+
+// Populates a tensor with a vector of data; fails if the data type doesn't
+// match the tensor type or if the number of elements doesn't match.
+template <typename T>
+inline void PopulateTensor(const std::vector<T>& data, TfLiteTensor* tensor) {
+ return PopulateTensor<T>(data.data(), data.size(), tensor);
+}
+
+template <>
+inline void PopulateTensor<std::string>(const std::vector<std::string>& data,
+ TfLiteTensor* tensor) {
+ if (tensor->type != kTfLiteString) {
+ TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name,
+ ". Requested STRING, got ",
+ TfLiteTypeGetName(tensor->type), ".")
+ .c_str());
+ }
+ tflite::DynamicBuffer input_buf;
+ for (const auto& value : data) {
+ input_buf.AddString(value.data(), value.length());
+ }
+ input_buf.WriteToTensorAsVector(tensor);
+}
+
+// Populates a tensor with one data item; fails if the data type doesn't
+// match the tensor type.
+template <typename T>
+inline void PopulateTensor(const T& data, TfLiteTensor* tensor) {
+ T* v = AssertAndReturnTypedTensor<T>(tensor);
+ *v = data;
+}
+
+template <>
+inline void PopulateTensor<std::string>(const std::string& data,
+ TfLiteTensor* tensor) {
+ tflite::DynamicBuffer input_buf;
+ input_buf.AddString(data.data(), data.length());
+ input_buf.WriteToTensorAsVector(tensor);
+}
+
+// Populates a vector from the tensor, fails if data type doesn't match tensor
+// type.
+template <typename T>
+inline void PopulateVector(const TfLiteTensor* tensor, std::vector<T>* data) {
+ AssertAndReturnTypedTensor<T>(tensor);
+ const T* results = GetTensorData<T>(tensor);
+  size_t num = tensor->bytes / sizeof(T);
+ data->reserve(num);
+ for (int i = 0; i < num; i++) {
+ data->emplace_back(results[i]);
+ }
+}
+
+template <>
+inline void PopulateVector<std::string>(const TfLiteTensor* tensor,
+ std::vector<std::string>* data) {
+ AssertAndReturnTypedTensor<std::string>(tensor);
+ int num = GetStringCount(tensor);
+ data->reserve(num);
+ for (int i = 0; i < num; i++) {
+ const auto& strref = tflite::GetString(tensor, i);
+ data->emplace_back(strref.str, strref.len);
+ }
+}
+
+// Returns the indices of a vector sorted in descending order of its values.
+template <typename T>
+std::vector<size_t> ReverseSortIndices(const std::vector<T>& v) {
+ std::vector<size_t> idx(v.size());
+ std::iota(idx.begin(), idx.end(), 0);
+
+ std::stable_sort(idx.begin(), idx.end(),
+ [&v](size_t i1, size_t i2) { return v[i2] < v[i1]; });
+
+ return idx;
+}
+
+// Returns the original (dequantized) value of the 'index'-th element of
+// 'tensor'.
+double Dequantize(const TfLiteTensor& tensor, int index);
+
+// Returns the index-th string from the tensor.
+std::string GetStringAtIndex(const TfLiteTensor* labels, int index);
+
+// Loads binary content of a file into a string.
+std::string LoadBinaryContent(const char* filename);
+
+// Gets the tensor whose name, as specified in the metadata, matches 'name'.
+template <typename TensorType>
+static TensorType* FindTensorByName(
+ const std::vector<TensorType*>& tensors,
+ const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
+ tensor_metadatas,
+ const std::string& name) {
+ if (tensor_metadatas == nullptr ||
+ tensor_metadatas->size() != tensors.size()) {
+ return nullptr;
+ }
+ for (int i = 0; i < tensor_metadatas->size(); i++) {
+ if (strcmp(name.data(), tensor_metadatas->Get(i)->name()->c_str()) == 0) {
+ return tensors[i];
+ }
+ }
+ return nullptr;
+}
+
+} // namespace core
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_
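
A minimal sketch of ranking scores with ReverseSortIndices() from this header (the scores are made up):

    #include <iostream>
    #include <vector>

    #include "tensorflow_lite_support/cc/task/core/task_utils.h"

    int main() {
      const std::vector<float> scores = {0.1f, 0.7f, 0.2f};
      // Prints indices 1, 2, 0: scores sorted from highest to lowest.
      for (size_t i : tflite::task::core::ReverseSortIndices(scores)) {
        std::cout << "class " << i << " score " << scores[i] << "\n";
      }
      return 0;
    }
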
diff --git a/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/tensorflow_lite_support/cc/task/core/tflite_engine.cc
new file mode 100644
index 00000000..cf923f6a
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/tflite_engine.cc
@@ -0,0 +1,297 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+
+#include <unistd.h>
+
+#include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/lite/builtin_ops.h"
+#include "tensorflow/lite/stderr_reporter.h"
+#include "tensorflow/lite/tools/verifier.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+
+#if TFLITE_USE_C_API
+#include "tensorflow/lite/c/c_api_experimental.h"
+#else
+#include "tensorflow/lite/kernels/register.h"
+#endif
+
+namespace tflite {
+namespace task {
+namespace core {
+
+#ifdef __ANDROID__
+// https://github.com/opencv/opencv/issues/14906
+// "ios_base::Init" object is not a part of Android's "iostream" header (in case
+// of clang toolchain, NDK 20).
+//
+// Ref1:
+// https://en.cppreference.com/w/cpp/io/ios_base/Init
+// The header <iostream> behaves as if it defines (directly or indirectly)
+// an instance of std::ios_base::Init with static storage duration
+//
+// Ref2:
+// https://github.com/gcc-mirror/gcc/blob/gcc-8-branch/libstdc%2B%2B-v3/include/std/iostream#L73-L74
+static std::ios_base::Init s_iostream_initializer;
+#endif
+
+using ::absl::StatusCode;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::TfLiteSupportStatus;
+
+int TfLiteEngine::ErrorReporter::Report(const char* format, va_list args) {
+ return std::vsnprintf(error_message, sizeof(error_message), format, args);
+}
+
+bool TfLiteEngine::Verifier::Verify(const char* data, int length,
+ tflite::ErrorReporter* reporter) {
+ return tflite::Verify(data, length, *op_resolver_, reporter);
+}
+
+#if TFLITE_USE_C_API
+TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
+ : model_(nullptr, TfLiteModelDelete),
+ resolver_(std::move(resolver)),
+ verifier_(resolver_.get()) {}
+#else
+TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
+ : model_(), resolver_(std::move(resolver)), verifier_(resolver_.get()) {}
+#endif
+
+std::vector<TfLiteTensor*> TfLiteEngine::GetInputs() {
+ Interpreter* interpreter = this->interpreter();
+ std::vector<TfLiteTensor*> tensors;
+ int input_count = InputCount(interpreter);
+ tensors.reserve(input_count);
+ for (int index = 0; index < input_count; index++) {
+ tensors.push_back(GetInput(interpreter, index));
+ }
+ return tensors;
+}
+
+std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() {
+ Interpreter* interpreter = this->interpreter();
+ std::vector<const TfLiteTensor*> tensors;
+ int output_count = OutputCount(interpreter);
+ tensors.reserve(output_count);
+ for (int index = 0; index < output_count; index++) {
+ tensors.push_back(GetOutput(interpreter, index));
+ }
+ return tensors;
+}
+
+// The following function is adapted from the code in
+// tflite::FlatBufferModel::VerifyAndBuildFromBuffer.
+void TfLiteEngine::VerifyAndBuildModelFromBuffer(const char* buffer_data,
+ size_t buffer_size) {
+#if TFLITE_USE_C_API
+ // First verify with the base flatbuffers verifier.
+ // This verifies that the model is a valid flatbuffer model.
+ flatbuffers::Verifier base_verifier(
+ reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
+ if (!VerifyModelBuffer(base_verifier)) {
+ TF_LITE_REPORT_ERROR(&error_reporter_,
+ "The model is not a valid Flatbuffer buffer");
+ model_ = nullptr;
+ return;
+ }
+ // Next verify with the extra verifier. This verifies that the model only
+ // uses operators supported by the OpResolver.
+ if (!verifier_.Verify(buffer_data, buffer_size, &error_reporter_)) {
+ model_ = nullptr;
+ return;
+ }
+ // Build the model.
+ model_.reset(TfLiteModelCreate(buffer_data, buffer_size));
+#else
+ model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
+ buffer_data, buffer_size, &verifier_, &error_reporter_);
+#endif
+}
+
+absl::Status TfLiteEngine::InitializeFromModelFileHandler() {
+ const char* buffer_data = model_file_handler_->GetFileContent().data();
+ size_t buffer_size = model_file_handler_->GetFileContent().size();
+ VerifyAndBuildModelFromBuffer(buffer_data, buffer_size);
+ if (model_ == nullptr) {
+ // To be replaced with a proper switch-case when TF Lite model builder
+ // returns a `TfLiteStatus` code capturing this type of error.
+ if (absl::StrContains(error_reporter_.error_message,
+ "The model is not a valid Flatbuffer")) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument, error_reporter_.error_message,
+ TfLiteSupportStatus::kInvalidFlatBufferError);
+ } else {
+ // TODO(b/154917059): augment status with another `TfLiteStatus` code when
+ // ready. And use a new `TfLiteStatus::kCoreTfLiteError` for the TFLS
+ // code, instead of the unspecified `kError`.
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown,
+ absl::StrCat(
+ "Could not build model from the provided pre-loaded flatbuffer: ",
+ error_reporter_.error_message));
+ }
+ }
+
+ ASSIGN_OR_RETURN(
+ model_metadata_extractor_,
+ tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer(
+ buffer_data, buffer_size));
+
+ return absl::OkStatus();
+}
+
+absl::Status TfLiteEngine::BuildModelFromFlatBuffer(const char* buffer_data,
+ size_t buffer_size) {
+ if (model_) {
+ return CreateStatusWithPayload(StatusCode::kInternal,
+ "Model already built");
+ }
+ external_file_.set_file_content(std::string(buffer_data, buffer_size));
+ ASSIGN_OR_RETURN(
+ model_file_handler_,
+ ExternalFileHandler::CreateFromExternalFile(&external_file_));
+ return InitializeFromModelFileHandler();
+}
+
+absl::Status TfLiteEngine::BuildModelFromFile(const std::string& file_name) {
+ if (model_) {
+ return CreateStatusWithPayload(StatusCode::kInternal,
+ "Model already built");
+ }
+ external_file_.set_file_name(file_name);
+ ASSIGN_OR_RETURN(
+ model_file_handler_,
+ ExternalFileHandler::CreateFromExternalFile(&external_file_));
+ return InitializeFromModelFileHandler();
+}
+
+absl::Status TfLiteEngine::BuildModelFromFileDescriptor(int file_descriptor) {
+ if (model_) {
+ return CreateStatusWithPayload(StatusCode::kInternal,
+ "Model already built");
+ }
+ external_file_.mutable_file_descriptor_meta()->set_fd(file_descriptor);
+ ASSIGN_OR_RETURN(
+ model_file_handler_,
+ ExternalFileHandler::CreateFromExternalFile(&external_file_));
+ return InitializeFromModelFileHandler();
+}
+
+absl::Status TfLiteEngine::BuildModelFromExternalFileProto(
+ const ExternalFile* external_file) {
+ if (model_) {
+ return CreateStatusWithPayload(StatusCode::kInternal,
+ "Model already built");
+ }
+ ASSIGN_OR_RETURN(model_file_handler_,
+ ExternalFileHandler::CreateFromExternalFile(external_file));
+ return InitializeFromModelFileHandler();
+}
+
+absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
+ tflite::proto::ComputeSettings compute_settings;
+ return InitInterpreter(compute_settings, num_threads);
+}
+
+#if TFLITE_USE_C_API
+const TfLiteRegistration* FindBuiltinOp(void* user_data,
+ TfLiteBuiltinOperator builtin_op,
+ int version) {
+ OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data);
+ tflite::BuiltinOperator op = static_cast<tflite::BuiltinOperator>(builtin_op);
+ return op_resolver->FindOp(op, version);
+}
+
+const TfLiteRegistration* FindCustomOp(void* user_data, const char* custom_op,
+ int version) {
+ OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data);
+ return op_resolver->FindOp(custom_op, version);
+}
+#endif
+
+absl::Status TfLiteEngine::InitInterpreter(
+ const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
+ if (model_ == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ "TF Lite FlatBufferModel is null. Please make sure to call one of the "
+ "BuildModelFrom methods before calling InitInterpreter.");
+ }
+#if TFLITE_USE_C_API
+ std::function<absl::Status(TfLiteDelegate*,
+ std::unique_ptr<Interpreter, InterpreterDeleter>*)>
+ initializer = [this, num_threads](
+ TfLiteDelegate* optional_delegate,
+ std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
+ -> absl::Status {
+ std::unique_ptr<TfLiteInterpreterOptions,
+ void (*)(TfLiteInterpreterOptions*)>
+ options{TfLiteInterpreterOptionsCreate(),
+ TfLiteInterpreterOptionsDelete};
+ TfLiteInterpreterOptionsSetOpResolver(options.get(), FindBuiltinOp,
+ FindCustomOp, resolver_.get());
+ TfLiteInterpreterOptionsSetNumThreads(options.get(), num_threads);
+ if (optional_delegate != nullptr) {
+ TfLiteInterpreterOptionsAddDelegate(options.get(), optional_delegate);
+ }
+ interpreter_out->reset(
+ TfLiteInterpreterCreateWithSelectedOps(model_.get(), options.get()));
+ if (*interpreter_out == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kAborted,
+ absl::StrCat("Could not build the TF Lite interpreter: "
+ "TfLiteInterpreterCreateWithSelectedOps failed: ",
+ error_reporter_.error_message));
+ }
+ return absl::OkStatus();
+ };
+#else
+ auto initializer =
+ [this, num_threads](
+ std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
+ -> absl::Status {
+ if (tflite::InterpreterBuilder(*model_, *resolver_)(
+ interpreter_out, num_threads) != kTfLiteOk) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown,
+ absl::StrCat("Could not build the TF Lite interpreter: ",
+ error_reporter_.error_message));
+ }
+ if (*interpreter_out == nullptr) {
+ return CreateStatusWithPayload(StatusCode::kInternal,
+ "TF Lite interpreter is null.");
+ }
+ return absl::OkStatus();
+ };
+#endif
+
+ absl::Status status =
+ interpreter_.InitializeWithFallback(initializer, compute_settings);
+
+ if (!status.ok() &&
+ !status.GetPayload(tflite::support::kTfLiteSupportPayload).has_value()) {
+ status = CreateStatusWithPayload(status.code(), status.message());
+ }
+ return status;
+}
+
+} // namespace core
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/core/tflite_engine.h b/tensorflow_lite_support/cc/task/core/tflite_engine.h
new file mode 100644
index 00000000..30f239da
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/tflite_engine.h
@@ -0,0 +1,245 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
+
+#include <sys/mman.h>
+
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+
+// If compiled with -DTFLITE_USE_C_API, this file will use the TF Lite C API
+// rather than the TF Lite C++ API.
+// TODO(b/168025296): eliminate the '#if TFLITE_USE_C_API' directives here and
+// elsewhere and instead use the C API unconditionally, once we have a suitable
+// replacement for the features of tflite::support::TfLiteInterpreterWrapper.
+#if TFLITE_USE_C_API
+#include "tensorflow/lite/c/c_api.h"
+#include "tensorflow/lite/core/api/verifier.h"
+#include "tensorflow/lite/tools/verifier.h"
+#else
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/model.h"
+#endif
+
+namespace tflite {
+namespace task {
+namespace core {
+
+// TfLiteEngine encapsulates logic for TFLite model initialization, inference
+// and error reporting.
+class TfLiteEngine {
+ public:
+ // Types.
+ using InterpreterWrapper = tflite::support::TfLiteInterpreterWrapper;
+#if TFLITE_USE_C_API
+ using Model = struct TfLiteModel;
+ using Interpreter = struct TfLiteInterpreter;
+ using ModelDeleter = void (*)(Model*);
+ using InterpreterDeleter = InterpreterWrapper::InterpreterDeleter;
+#else
+ using Model = tflite::FlatBufferModel;
+ using Interpreter = tflite::Interpreter;
+ using ModelDeleter = std::default_delete<Model>;
+ using InterpreterDeleter = std::default_delete<Interpreter>;
+#endif
+
+ // Constructors.
+ explicit TfLiteEngine(
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+ // Model is neither copyable nor movable.
+ TfLiteEngine(const TfLiteEngine&) = delete;
+ TfLiteEngine& operator=(const TfLiteEngine&) = delete;
+
+ // Accessors.
+ static int32_t InputCount(const Interpreter* interpreter) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetInputTensorCount(interpreter);
+#else
+ return interpreter->inputs().size();
+#endif
+ }
+ static int32_t OutputCount(const Interpreter* interpreter) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetOutputTensorCount(interpreter);
+#else
+ return interpreter->outputs().size();
+#endif
+ }
+ static TfLiteTensor* GetInput(Interpreter* interpreter, int index) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetInputTensor(interpreter, index);
+#else
+ return interpreter->tensor(interpreter->inputs()[index]);
+#endif
+ }
+ // Same as above, but const.
+ static const TfLiteTensor* GetInput(const Interpreter* interpreter,
+ int index) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetInputTensor(interpreter, index);
+#else
+ return interpreter->tensor(interpreter->inputs()[index]);
+#endif
+ }
+ static TfLiteTensor* GetOutput(Interpreter* interpreter, int index) {
+#if TFLITE_USE_C_API
+ // We need a const_cast here, because the TF Lite C API only has a non-const
+ // version of GetOutputTensor (in part because C doesn't support overloading
+ // on const).
+ return const_cast<TfLiteTensor*>(
+ TfLiteInterpreterGetOutputTensor(interpreter, index));
+#else
+ return interpreter->tensor(interpreter->outputs()[index]);
+#endif
+ }
+ // Same as above, but const.
+ static const TfLiteTensor* GetOutput(const Interpreter* interpreter,
+ int index) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetOutputTensor(interpreter, index);
+#else
+ return interpreter->tensor(interpreter->outputs()[index]);
+#endif
+ }
+
+ std::vector<TfLiteTensor*> GetInputs();
+ std::vector<const TfLiteTensor*> GetOutputs();
+
+ const Model* model() const { return model_.get(); }
+ Interpreter* interpreter() { return interpreter_.get(); }
+ const Interpreter* interpreter() const { return interpreter_.get(); }
+ InterpreterWrapper* interpreter_wrapper() { return &interpreter_; }
+ const tflite::metadata::ModelMetadataExtractor* metadata_extractor() const {
+ return model_metadata_extractor_.get();
+ }
+
+ // Builds the TF Lite FlatBufferModel (model_) from the raw FlatBuffer data
+ // whose ownership remains with the caller, and which must outlive the current
+ // object. This performs extra verification on the input data using
+ // tflite::Verify.
+ absl::Status BuildModelFromFlatBuffer(const char* buffer_data,
+ size_t buffer_size);
+
+ // Builds the TF Lite model from a given file.
+ absl::Status BuildModelFromFile(const std::string& file_name);
+
+ // Builds the TF Lite model from a given file descriptor using mmap(2).
+ absl::Status BuildModelFromFileDescriptor(int file_descriptor);
+
+ // Builds the TFLite model from the provided ExternalFile proto, which must
+ // outlive the current object.
+ absl::Status BuildModelFromExternalFileProto(
+ const ExternalFile* external_file);
+
+ // Initializes interpreter with encapsulated model.
+  // Note: setting num_threads to -1 lets the TFLite runtime choose the
+  // value automatically.
+ absl::Status InitInterpreter(int num_threads = 1);
+
+ // Same as above, but allows specifying `compute_settings` for acceleration.
+ absl::Status InitInterpreter(
+ const tflite::proto::ComputeSettings& compute_settings,
+ int num_threads = 1);
+
+  // Cancels the ongoing `Invoke()` call, if any and if possible. This method
+ // can be called from a different thread than the one where `Invoke()` is
+ // running.
+ void Cancel() {
+#if TFLITE_USE_C_API
+ // NOP.
+#else
+ interpreter_.Cancel();
+#endif
+ }
+
+ protected:
+ // TF Lite's DefaultErrorReporter() outputs to stderr. This one captures the
+  // error into a string so that it can be used to complement absl::Status
+  // error messages.
+ struct ErrorReporter : public tflite::ErrorReporter {
+ // Last error message captured by this error reporter.
+ char error_message[256];
+ int Report(const char* format, va_list args) override;
+ };
+ // Custom error reporter capturing low-level TF Lite error messages.
+ ErrorReporter error_reporter_;
+
+ private:
+ // Direct wrapper around tflite::TfLiteVerifier which checks the integrity of
+ // the FlatBuffer data provided as input.
+ class Verifier : public tflite::TfLiteVerifier {
+ public:
+ explicit Verifier(const tflite::OpResolver* op_resolver)
+ : op_resolver_(op_resolver) {}
+ bool Verify(const char* data, int length,
+ tflite::ErrorReporter* reporter) override;
+ // The OpResolver to be used to build the TF Lite interpreter.
+ const tflite::OpResolver* op_resolver_;
+ };
+
+ // Verifies that the supplied buffer refers to a valid flatbuffer model,
+ // and that it uses only operators that are supported by the OpResolver
+ // that was passed to the TfLiteEngine constructor, and then builds
+ // the model from the buffer and stores it in 'model_'.
+ void VerifyAndBuildModelFromBuffer(const char* buffer_data,
+ size_t buffer_size);
+
+  // Gets the buffer from the file handler; verifies and builds the model
+  // from the buffer; if successful, sets 'model_metadata_extractor_' to be
+  // a TF Lite Metadata extractor for the model; and returns an appropriate
+  // Status.
+ absl::Status InitializeFromModelFileHandler();
+
+ // TF Lite model and interpreter for actual inference.
+ std::unique_ptr<Model, ModelDeleter> model_;
+
+ // Interpreter wrapper built from the model.
+ InterpreterWrapper interpreter_;
+
+ // TFLite Metadata extractor built from the model.
+ std::unique_ptr<tflite::metadata::ModelMetadataExtractor>
+ model_metadata_extractor_;
+
+ // Mechanism used by TF Lite to map Ops referenced in the FlatBuffer model to
+ // actual implementation. Defaults to TF Lite BuiltinOpResolver.
+ std::unique_ptr<tflite::OpResolver> resolver_;
+
+ // Extra verifier for FlatBuffer input data.
+ Verifier verifier_;
+
+ // ExternalFile and corresponding ExternalFileHandler for models loaded from
+ // disk or file descriptor.
+ ExternalFile external_file_;
+ std::unique_ptr<ExternalFileHandler> model_file_handler_;
+};
+
+} // namespace core
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
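
The intended call sequence for this engine, sketched under the assumption of a valid model at the hypothetical path below:

    #include <iostream>

    #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"

    int main() {
      // Defaults to the TF Lite BuiltinOpResolver.
      tflite::task::core::TfLiteEngine engine;
      if (!engine.BuildModelFromFile("/path/to/model.tflite").ok()) return 1;
      if (!engine.InitInterpreter(/*num_threads=*/2).ok()) return 1;
      std::cout << "inputs: " << engine.GetInputs().size()
                << ", outputs: " << engine.GetOutputs().size() << "\n";
      return 0;
    }
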
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD b/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD
new file mode 100644
index 00000000..33b6f6a6
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD
@@ -0,0 +1,118 @@
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "bert_nl_classifier_c_api.h",
+ "nl_classifier_c_api.h",
+ "nl_classifier_c_api_common.h",
+])
+
+cc_library(
+ name = "nl_classifier",
+ srcs = [
+ "nl_classifier.cc",
+ ],
+ hdrs = [
+ "nl_classifier.h",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:base_task_api",
+ "//tensorflow_lite_support/cc/task/core:category",
+ "//tensorflow_lite_support/cc/task/core:task_api_factory",
+ "//tensorflow_lite_support/cc/task/core:task_utils",
+ "//tensorflow_lite_support/cc/text/tokenizers:regex_tokenizer",
+ "//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
+ "//tensorflow_lite_support/cc/utils:common_utils",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:string",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/core/api",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+ ],
+)
+
+cc_library(
+ name = "nl_classifier_c_api",
+ srcs = [
+ "nl_classifier_c_api.cc",
+ ],
+ hdrs = [
+ "nl_classifier_c_api.h",
+ "nl_classifier_c_api_common.h",
+ ],
+ visibility = ["//tensorflow_lite_support:__subpackages__"],
+ deps = [
+ ":nl_classifier",
+ ":nl_classifier_c_api_common",
+ "//tensorflow_lite_support/cc/task/core:category",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "bert_nl_classifier",
+ srcs = [
+ "bert_nl_classifier.cc",
+ ],
+ hdrs = [
+ "bert_nl_classifier.h",
+ ],
+ deps = [
+ ":nl_classifier",
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:category",
+ "//tensorflow_lite_support/cc/task/core:task_api_factory",
+ "//tensorflow_lite_support/cc/task/core:task_utils",
+ "//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
+ "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@org_tensorflow//tensorflow/lite:string",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/core/api",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ ],
+)
+
+cc_library(
+ name = "bert_nl_classifier_c_api",
+ srcs = [
+ "bert_nl_classifier_c_api.cc",
+ ],
+ hdrs = [
+ "bert_nl_classifier_c_api.h",
+ "nl_classifier_c_api_common.h",
+ ],
+ visibility = ["//tensorflow_lite_support:__subpackages__"],
+ deps = [
+ ":bert_nl_classifier",
+ ":nl_classifier_c_api_common",
+ "//tensorflow_lite_support/cc/task/core:category",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "nl_classifier_c_api_common",
+ srcs = [
+ "nl_classifier_c_api_common.cc",
+ ],
+ hdrs = [
+ "nl_classifier_c_api_common.h",
+ ],
+ visibility = ["//tensorflow_lite_support:__subpackages__"],
+)
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
new file mode 100644
index 00000000..d689c9e8
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
@@ -0,0 +1,198 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/string_type.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/core/category.h"
+#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace nlclassifier {
+
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
+using ::tflite::support::text::tokenizer::TokenizerResult;
+using ::tflite::task::core::FindTensorByName;
+using ::tflite::task::core::PopulateTensor;
+
+namespace {
+constexpr char kIdsTensorName[] = "ids";
+constexpr char kMaskTensorName[] = "mask";
+constexpr char kSegmentIdsTensorName[] = "segment_ids";
+constexpr char kScoreTensorName[] = "probability";
+constexpr char kClassificationToken[] = "[CLS]";
+constexpr char kSeparator[] = "[SEP]";
+constexpr int kTokenizerProcessUnitIndex = 0;
+} // namespace
+
+absl::Status BertNLClassifier::Preprocess(
+ const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
+ auto* input_tensor_metadatas =
+ GetMetadataExtractor()->GetInputTensorMetadata();
+ auto* ids_tensor =
+ FindTensorByName(input_tensors, input_tensor_metadatas, kIdsTensorName);
+ auto* mask_tensor =
+ FindTensorByName(input_tensors, input_tensor_metadatas, kMaskTensorName);
+ auto* segment_ids_tensor = FindTensorByName(
+ input_tensors, input_tensor_metadatas, kSegmentIdsTensorName);
+
+ std::string processed_input = input;
+ absl::AsciiStrToLower(&processed_input);
+
+  const TokenizerResult input_tokenize_results =
+      tokenizer_->Tokenize(processed_input);
+
+ // 2 accounts for [CLS], [SEP]
+ absl::Span<const std::string> query_tokens =
+ absl::MakeSpan(input_tokenize_results.subwords.data(),
+ input_tokenize_results.subwords.data() +
+ std::min(static_cast<size_t>(kMaxSeqLen - 2),
+ input_tokenize_results.subwords.size()));
+
+ std::vector<std::string> tokens;
+ tokens.reserve(2 + query_tokens.size());
+ // Start of generating the features.
+ tokens.push_back(kClassificationToken);
+ // For query input.
+ for (const auto& query_token : query_tokens) {
+ tokens.push_back(query_token);
+ }
+  // Trailing separator token.
+ tokens.push_back(kSeparator);
+
+ std::vector<int> input_ids(kMaxSeqLen, 0);
+ std::vector<int> input_mask(kMaxSeqLen, 0);
+ // Convert tokens back into ids and set mask
+ for (int i = 0; i < tokens.size(); ++i) {
+ tokenizer_->LookupId(tokens[i], &input_ids[i]);
+ input_mask[i] = 1;
+ }
+ // |<-----------kMaxSeqLen---------->|
+ // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
+ // input_masks 1 1 1... 1 1 0 0... 0
+ // segment_ids 0 0 0... 0 0 0 0... 0
+
+ PopulateTensor(input_ids, ids_tensor);
+ PopulateTensor(input_mask, mask_tensor);
+ PopulateTensor(std::vector<int>(kMaxSeqLen, 0), segment_ids_tensor);
+
+ return absl::OkStatus();
+}
+
+StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const std::string& /*input*/) {
+ if (output_tensors.size() != 1) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ absl::StrFormat("BertNLClassifier models are expected to have only 1 "
+ "output, found %d",
+ output_tensors.size()),
+ TfLiteSupportStatus::kInvalidNumOutputTensorsError);
+ }
+ const TfLiteTensor* scores = FindTensorByName(
+ output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
+ kScoreTensorName);
+
+  // Optional labels were already extracted from metadata at initialization.
+ return BuildResults(scores, /*labels=*/nullptr);
+}
+
+StatusOr<std::unique_ptr<BertNLClassifier>>
+BertNLClassifier::CreateFromFile(
+ const std::string& path_to_model_with_metadata,
+ std::unique_ptr<tflite::OpResolver> resolver) {
+ std::unique_ptr<BertNLClassifier> bert_nl_classifier;
+ ASSIGN_OR_RETURN(bert_nl_classifier,
+ core::TaskAPIFactory::CreateFromFile<BertNLClassifier>(
+ path_to_model_with_metadata, std::move(resolver)));
+ RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
+ return std::move(bert_nl_classifier);
+}
+
+StatusOr<std::unique_ptr<BertNLClassifier>>
+BertNLClassifier::CreateFromBuffer(
+ const char* model_with_metadata_buffer_data,
+ size_t model_with_metadata_buffer_size,
+ std::unique_ptr<tflite::OpResolver> resolver) {
+ std::unique_ptr<BertNLClassifier> bert_nl_classifier;
+ ASSIGN_OR_RETURN(bert_nl_classifier,
+ core::TaskAPIFactory::CreateFromBuffer<BertNLClassifier>(
+ model_with_metadata_buffer_data,
+ model_with_metadata_buffer_size, std::move(resolver)));
+ RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
+ return std::move(bert_nl_classifier);
+}
+
+StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromFd(
+ int fd, std::unique_ptr<tflite::OpResolver> resolver) {
+ std::unique_ptr<BertNLClassifier> bert_nl_classifier;
+ ASSIGN_OR_RETURN(
+ bert_nl_classifier,
+ core::TaskAPIFactory::CreateFromFileDescriptor<BertNLClassifier>(
+ fd, std::move(resolver)));
+ RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
+ return std::move(bert_nl_classifier);
+}
+
+absl::Status BertNLClassifier::InitializeFromMetadata() {
+ // Set up mandatory tokenizer.
+ const ProcessUnit* tokenizer_process_unit =
+ GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
+ if (tokenizer_process_unit == nullptr) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "No input process unit found from metadata.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+ ASSIGN_OR_RETURN(tokenizer_,
+ CreateTokenizerFromProcessUnit(tokenizer_process_unit,
+ GetMetadataExtractor()));
+
+ // Set up optional label vector.
+ TrySetLabelFromMetadata(
+ GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex))
+ .IgnoreError();
+ return absl::OkStatus();
+}
+
+} // namespace nlclassifier
+} // namespace text
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h
new file mode 100644
index 00000000..0c709ee0
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_
+
+#include <stddef.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/string_type.h"
+#include "tensorflow_lite_support/cc/task/core/category.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace nlclassifier {
+
+// Classifier API for NLClassification tasks with Bert models; categorizes a
+// string into different classes.
+//
+// The API expects a Bert-based TFLite model with metadata populated.
+// The metadata should contain the following information:
+// - input_process_units for a Wordpiece/Sentencepiece Tokenizer
+// - 3 input tensors with names "ids", "mask" and "segment_ids"
+// - 1 output tensor of type float32[1, 2], with an optionally attached label
+//   file. If a label file is attached, it should be a plain text file with
+//   one label per line; the number of labels should match the number of
+//   categories the model outputs.
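+//
+// A minimal usage sketch (the model path is a placeholder):
+//
+//   auto classifier_or =
+//       BertNLClassifier::CreateFromFile("/path/to/model.tflite");
+//   if (classifier_or.ok()) {
+//     std::vector<core::Category> categories =
+//         classifier_or.value()->Classify("input text");
+//   }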
+
+class BertNLClassifier : public NLClassifier {
+ public:
+ using NLClassifier::NLClassifier;
+ // Max number of tokens to pass to the model.
+ static constexpr int kMaxSeqLen = 128;
+
+  // Factory function to create a BertNLClassifier from a TFLite model with
+  // metadata.
+ static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>>
+ CreateFromFile(
+ const std::string& path_to_model_with_metadata,
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+
+  // Factory function to create a BertNLClassifier from an in-memory buffer of
+  // a TFLite model with metadata.
+ static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>>
+ CreateFromBuffer(
+ const char* model_with_metadata_buffer_data,
+ size_t model_with_metadata_buffer_size,
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+
+ // Factory function to create a BertNLClassifier from the file descriptor of a
+ // TFLite model with metadata.
+ static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>>
+ CreateFromFd(
+ int fd, std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+
+ protected:
+  // Runs tokenization on the input text and constructs the three input
+  // tensors ("ids", "mask" and "segment_ids") for the model.
+ absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
+ const std::string& input) override;
+
+  // Extracts the model output and creates results using the label file
+  // attached in metadata. If no label file is attached, uses the output score
+  // indices as labels.
+ tflite::support::StatusOr<std::vector<core::Category>> Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const std::string& input) override;
+
+ private:
+  // Initializes the API with the tokenizer and label file set in the metadata.
+ absl::Status InitializeFromMetadata();
+
+ std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
+};
+
+} // namespace nlclassifier
+} // namespace text
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc
new file mode 100644
index 00000000..0decc497
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc
@@ -0,0 +1,70 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h"
+
+#include <cstring>
+#include <memory>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow_lite_support/cc/task/core/category.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
+
+using CategoryCPP = ::tflite::task::core::Category;
+using BertNLClassifierCPP =
+ ::tflite::task::text::nlclassifier::BertNLClassifier;
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+struct BertNLClassifier {
+ std::unique_ptr<BertNLClassifierCPP> impl;
+};
+
+BertNLClassifier* BertNLClassifierFromFile(const char* model_path) {
+ auto classifier_status =
+ BertNLClassifierCPP::CreateFromFile(std::string(model_path));
+ if (classifier_status.ok()) {
+ return new BertNLClassifier{.impl = std::unique_ptr<BertNLClassifierCPP>(
+ dynamic_cast<BertNLClassifierCPP*>(
+ classifier_status.value().release()))};
+ } else {
+ return nullptr;
+ }
+}
+
+Categories* BertNLClassifierClassify(const BertNLClassifier* classifier,
+ const char* text) {
+ std::vector<CategoryCPP> results =
+ classifier->impl->Classify(absl::string_view(text).data());
+ size_t size = results.size();
+ auto* categories = new Category[size];
+
+ for (size_t i = 0; i < size; ++i) {
+ categories[i].text = strdup(results[i].class_name.c_str());
+ categories[i].score = results[i].score;
+ }
+
+ auto* c_categories = new Categories;
+ c_categories->size = size;
+ c_categories->categories = categories;
+ return c_categories;
+}
+
+void BertNLClassifierDelete(BertNLClassifier* classifier) { delete classifier; }
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h
new file mode 100644
index 00000000..1d0b8b67
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h
@@ -0,0 +1,60 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_C_API_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_C_API_H_
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h"
+// --------------------------------------------------------------------------
+/// C API for BertNLClassifier.
+///
+/// The API leans towards simplicity and uniformity instead of convenience, as
+/// most usage will be by language-specific wrappers. It provides largely the
+/// same set of functionality as that of the C++ TensorFlow Lite
+/// `BertNLClassifier` API, but is useful for shared libraries where having
+/// a stable ABI boundary is important.
+///
+/// Usage:
+/// <pre><code>
+/// // Create the classifier.
+/// BertNLClassifier* classifier =
+///     BertNLClassifierFromFile("/path/to/model.tflite");
+///
+/// // Classify the input text.
+/// Categories* categories = BertNLClassifierClassify(classifier, text);
+///
+/// // Dispose of the classifier and the classification result.
+/// BertNLClassifierDelete(classifier);
+/// NLClassifierCategoriesDelete(categories);
+/// </code></pre>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef struct BertNLClassifier BertNLClassifier;
+
+// Creates a BertNLClassifier from the model path; returns nullptr if the file
+// doesn't exist or is not a well-formed TFLite model.
+extern BertNLClassifier* BertNLClassifierFromFile(const char* model_path);
+
+// Invokes the encapsulated TFLite model and classifies the input text.
+extern struct Categories* BertNLClassifierClassify(
+ const BertNLClassifier* classifier, const char* text);
+
+extern void BertNLClassifierDelete(BertNLClassifier* classifier);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_C_API_H_
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
new file mode 100644
index 00000000..1643e3e0
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
@@ -0,0 +1,467 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
+
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/category.h"
+#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+#include "tensorflow_lite_support/cc/utils/common_utils.h"
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace nlclassifier {
+
+using ::absl::StatusCode;
+using ::flatbuffers::Offset;
+using ::flatbuffers::Vector;
+using ::tflite::TensorMetadata;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::support::text::tokenizer::RegexTokenizer;
+using ::tflite::support::text::tokenizer::Tokenizer;
+using ::tflite::support::text::tokenizer::TokenizerResult;
+using ::tflite::support::utils::LoadVocabFromBuffer;
+using ::tflite::task::core::Category;
+using ::tflite::task::core::Dequantize;
+using ::tflite::task::core::GetStringAtIndex;
+using ::tflite::task::core::PopulateTensor;
+
+namespace {
+constexpr int kRegexTokenizerInputTensorIndex = 0;
+constexpr int kRegexTokenizerProcessUnitIndex = 0;
+
+StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
+ associated_files,
+ const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
+ if (associated_files == nullptr || associated_files->size() < 1 ||
+ associated_files->Get(0)->name() == nullptr) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "Invalid vocab_file from input process unit.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+ ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
+ metadata_extractor->GetAssociatedFile(
+ associated_files->Get(0)->name()->str()));
+ return vocab_buffer;
+}
+
+StatusOr<std::unique_ptr<Tokenizer>> CreateRegexTokenizerFromProcessUnit(
+ const tflite::ProcessUnit* tokenizer_process_unit,
+ const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
+ if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "No metadata or input process unit found.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+
+ if (tokenizer_process_unit->options_type() !=
+ ProcessUnitOptions_RegexTokenizerOptions) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kNotFound,
+ absl::StrCat(
+ "Incorrect options_type:", tokenizer_process_unit->options_type(),
+ " need RegexTokenizerOptions."),
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+
+ const tflite::RegexTokenizerOptions* options =
+ tokenizer_process_unit->options_as<RegexTokenizerOptions>();
+ ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
+ CheckAndLoadFirstAssociatedFile(options->vocab_file(),
+ metadata_extractor));
+ if (options->delim_regex_pattern() == nullptr) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "Invalid delim_regex_pattern from input process unit.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+
+ std::unique_ptr<RegexTokenizer> regex_tokenizer =
+ absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(),
+ vocab_buffer.data(),
+ vocab_buffer.size());
+
+ int unknown_token_id = 0;
+ if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "RegexTokenizer doesn't have <UNKNOWN> token.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+
+ int pad_token_id = 0;
+ if (!regex_tokenizer->GetPadToken(&pad_token_id)) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "RegexTokenizer doesn't have <PAD> token.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+ return regex_tokenizer;
+}
+
+} // namespace
+
+const NLClassifierOptions& NLClassifier::GetOptions() const { return options_; }
+
+absl::Status NLClassifier::TrySetLabelFromMetadata(
+ const TensorMetadata* metadata) {
+ if (metadata == nullptr) {
+ return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
+ "Metadata not found for output tensor",
+ TfLiteSupportStatus::kMetadataNotFoundError);
+ }
+ const auto* associated_files = metadata->associated_files();
+ if (associated_files == nullptr || associated_files->size() == 0) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "No label file found for tensor metadata.",
+ TfLiteSupportStatus::kMetadataMissingLabelsError);
+ }
+ const tflite::AssociatedFile* associated_file =
+ associated_files->Get(kOutputTensorLabelFileIndex);
+ if (associated_file->type() != AssociatedFileType_TENSOR_AXIS_LABELS) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "Incorrect label type found for tensor metadata.",
+ TfLiteSupportStatus::kMetadataMissingLabelsError);
+ }
+ tflite::support::StatusOr<absl::string_view> label_buffer =
+ GetMetadataExtractor()->GetAssociatedFile(
+          associated_file->name()->str());
+ if (label_buffer.ok()) {
+ labels_vector_ =
+ absl::make_unique<std::vector<std::string>>(LoadVocabFromBuffer(
+ label_buffer.value().data(), label_buffer.value().size()));
+ return absl::OkStatus();
+ } else {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "Failed to extract label file from metadata.",
+ TfLiteSupportStatus::kMetadataMissingLabelsError);
+ }
+}
+
+std::vector<Category> NLClassifier::Classify(const std::string& text) {
+  // The NLClassifier implementations of Preprocess() and Postprocess() never
+  // return errors: it is safe to call value() directly.
+ return Infer(text).value();
+}
+
+absl::Status NLClassifier::Preprocess(
+ const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
+ TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
+ input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
+ options_.input_tensor_name, options_.input_tensor_index);
+ if (input_tensor == nullptr) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "No input tensor found from NLClassifierOptions.",
+ TfLiteSupportStatus::kInputTensorNotFoundError);
+ }
+
+ if (HasRegexTokenizerMetadata()) {
+ // |<-------sentence_length-------->|
+ // input_tensor <START>, t1, t2... <PAD>, <PAD>...
+    // <START> is optional; t1, t2... are replaced by <UNKNOWN> if they are
+    // not found in the tokenizer vocab.
+ TokenizerResult result = tokenizer_->Tokenize(input);
+
+ size_t max_sentence_length = input_tensor->dims->size == 2
+ ? input_tensor->dims->data[1]
+ : input_tensor->dims->data[0];
+
+ int unknown_token_id = 0;
+ tokenizer_->GetUnknownToken(&unknown_token_id);
+
+ int pad_token_id = 0;
+ tokenizer_->GetPadToken(&pad_token_id);
+
+ std::vector<int> input_tokens(max_sentence_length, pad_token_id);
+ int start_token_id = 0;
+ size_t input_token_index = 0;
+ if (tokenizer_->GetStartToken(&start_token_id)) {
+ input_tokens[0] = start_token_id;
+ input_token_index = 1;
+ }
+
+ for (size_t i = 0; (i < result.subwords.size()) &&
+ (input_token_index < max_sentence_length);
+ ++i, ++input_token_index) {
+ const std::string& token = result.subwords[i];
+ int token_id = 0;
+ if (tokenizer_->LookupId(token, &token_id)) {
+ input_tokens[input_token_index] = token_id;
+ } else {
+ input_tokens[input_token_index] = unknown_token_id;
+ }
+ }
+
+ PopulateTensor(input_tokens, input_tensor);
+ } else {
+ PopulateTensor(input, input_tensor);
+ }
+ return absl::OkStatus();
+}
+
+StatusOr<std::vector<Category>> NLClassifier::Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const std::string& /*input*/) {
+ return BuildResults(
+ FindTensorWithNameOrIndex(
+ output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
+ options_.output_score_tensor_name,
+ options_.output_score_tensor_index),
+      // The label tensor is an output tensor, so its metadata is looked up in
+      // the output tensor metadata as well.
+      FindTensorWithNameOrIndex(
+          output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
+          options_.output_label_tensor_name,
+          options_.output_label_tensor_index));
+}
+
+std::vector<Category> NLClassifier::BuildResults(const TfLiteTensor* scores,
+ const TfLiteTensor* labels) {
+ bool use_index_as_labels = (labels_vector_ == nullptr) && (labels == nullptr);
+ // Some models output scores with transposed shape [1, categories]
+ int categories =
+ scores->dims->size == 2 ? scores->dims->data[1] : scores->dims->data[0];
+
+ std::vector<Category> predictions;
+ predictions.reserve(categories);
+
+ bool should_dequantize = scores->type == kTfLiteUInt8 ||
+ scores->type == kTfLiteInt8 ||
+ scores->type == kTfLiteInt16;
+ for (int index = 0; index < categories; index++) {
+ std::string label;
+ if (use_index_as_labels) {
+ label = std::to_string(index);
+ } else if (labels_vector_ == nullptr) {
+ if (labels->type == kTfLiteString) {
+ label = GetStringAtIndex(labels, index);
+ } else if (labels->type == kTfLiteInt32) {
+ label = std::to_string(GetTensorData<int>(labels)[index]);
+ }
+ } else {
+ label = (*labels_vector_)[index];
+ }
+ if (should_dequantize) {
+ predictions.push_back(Category(label, Dequantize(*scores, index)));
+ } else if (scores->type == kTfLiteBool) {
+ predictions.push_back(
+ Category(label, GetTensorData<bool>(scores)[index] ? 1.0 : 0.0));
+ } else {
+ predictions.push_back(
+ Category(label, scores->type == kTfLiteFloat32
+ ? GetTensorData<float>(scores)[index]
+ : GetTensorData<double>(scores)[index]));
+ }
+ }
+
+ return predictions;
+}
+
+absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
+ options_ = options;
+  // The input tensor should be of type STRING, or INT32 when a RegexTokenizer
+  // is set up in the metadata.
+ auto input_tensor = FindTensorWithNameOrIndex(
+ GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(),
+ options.input_tensor_name, options.input_tensor_index);
+ if (input_tensor == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat("No input tensor found with name ",
+ options.input_tensor_name, " or at index ",
+ options.input_tensor_index),
+ TfLiteSupportStatus::kInputTensorNotFoundError);
+ }
+ if (HasRegexTokenizerMetadata()) {
+ if (input_tensor->type != kTfLiteInt32) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
+ ". Requested INT32, got ",
+ TfLiteTypeGetName(input_tensor->type), "."),
+ TfLiteSupportStatus::kInvalidInputTensorTypeError);
+ }
+ RETURN_IF_ERROR(SetupRegexTokenizer());
+ } else {
+ if (input_tensor->type != kTfLiteString) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
+ ". Requested STRING, got ",
+ TfLiteTypeGetName(input_tensor->type), "."),
+ TfLiteSupportStatus::kInvalidInputTensorTypeError);
+ }
+ }
+
+  // The output score tensor should be of type UINT8/INT8/INT16 (quantized),
+  // FLOAT32/FLOAT64 (dequantized) or BOOL.
+ std::vector<const TfLiteTensor*> output_tensors = GetOutputTensors();
+ const Vector<Offset<TensorMetadata>>* output_tensor_metadatas =
+ GetMetadataExtractor()->GetOutputTensorMetadata();
+
+ const auto scores = FindTensorWithNameOrIndex(
+ output_tensors, output_tensor_metadatas, options.output_score_tensor_name,
+ options.output_score_tensor_index);
+ if (scores == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat("No output score tensor found with name ",
+ options.output_score_tensor_name, " or at index ",
+ options.output_score_tensor_index),
+ TfLiteSupportStatus::kOutputTensorNotFoundError);
+ }
+ static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteInt8,
+ kTfLiteInt16, kTfLiteFloat32,
+ kTfLiteFloat64, kTfLiteBool};
+ if (!absl::c_linear_search(valid_types, scores->type)) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat("Type mismatch for score tensor ", scores->name,
+ ". Requested one of these types: "
+ "INT8/UINT8/INT16/FLOAT32/FLOAT64/BOOL, got ",
+ TfLiteTypeGetName(scores->type), "."),
+ TfLiteSupportStatus::kInvalidOutputTensorTypeError);
+ }
+
+  // Extract the associated label file from the output score tensor metadata,
+  // if one exists. Well-formed metadata should contain the same number of
+  // tensor entries as the model.
+ if (output_tensor_metadatas &&
+ output_tensor_metadatas->size() == output_tensors.size()) {
+ for (int i = 0; i < output_tensor_metadatas->size(); ++i) {
+ const tflite::TensorMetadata* metadata = output_tensor_metadatas->Get(i);
+ if ((metadata->name() && metadata->name()->string_view() ==
+ options.output_score_tensor_name) ||
+ i == options.output_score_tensor_index) {
+ if (TrySetLabelFromMetadata(metadata).ok()) {
+ return absl::OkStatus();
+ }
+ }
+ }
+ }
+
+  // If labels_vector_ was not set up from metadata, try to register the
+  // output label tensor from the options.
+ if (labels_vector_ == nullptr) {
+    // The output label tensor, if one exists, should be of type STRING or
+    // INT32.
+ auto labels = FindTensorWithNameOrIndex(
+ output_tensors, output_tensor_metadatas,
+ options.output_label_tensor_name, options.output_label_tensor_index);
+ if (labels != nullptr && labels->type != kTfLiteString &&
+ labels->type != kTfLiteInt32) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+          absl::StrCat("Type mismatch for label tensor ", labels->name,
+                       ". Requested STRING or INT32, got ",
+                       TfLiteTypeGetName(labels->type), "."),
+ TfLiteSupportStatus::kInvalidOutputTensorTypeError);
+ }
+ }
+ return absl::OkStatus();
+}
+
+StatusOr<std::unique_ptr<NLClassifier>>
+NLClassifier::CreateFromBufferAndOptions(
+ const char* model_buffer_data, size_t model_buffer_size,
+ const NLClassifierOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver) {
+ std::unique_ptr<NLClassifier> nl_classifier;
+ ASSIGN_OR_RETURN(
+ nl_classifier,
+ core::TaskAPIFactory::CreateFromBuffer<NLClassifier>(
+ model_buffer_data, model_buffer_size, std::move(resolver)));
+ RETURN_IF_ERROR(nl_classifier->Initialize(options));
+ return std::move(nl_classifier);
+}
+
+StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
+ const std::string& path_to_model, const NLClassifierOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver) {
+ std::unique_ptr<NLClassifier> nl_classifier;
+ ASSIGN_OR_RETURN(nl_classifier,
+ core::TaskAPIFactory::CreateFromFile<NLClassifier>(
+ path_to_model, std::move(resolver)));
+ RETURN_IF_ERROR(nl_classifier->Initialize(options));
+ return std::move(nl_classifier);
+}
+
+StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
+ int fd, const NLClassifierOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver) {
+ std::unique_ptr<NLClassifier> nl_classifier;
+ ASSIGN_OR_RETURN(nl_classifier,
+ core::TaskAPIFactory::CreateFromFileDescriptor<NLClassifier>(
+ fd, std::move(resolver)));
+ RETURN_IF_ERROR(nl_classifier->Initialize(options));
+ return std::move(nl_classifier);
+}
+
+bool NLClassifier::HasRegexTokenizerMetadata() {
+ const TensorMetadata* input_tensor_metadata =
+ GetMetadataExtractor()->GetInputTensorMetadata(
+ kRegexTokenizerInputTensorIndex);
+ if (input_tensor_metadata == nullptr) {
+ return false;
+ }
+ tflite::support::StatusOr<const tflite::ProcessUnit*> status =
+ GetMetadataExtractor()->FindFirstProcessUnit(
+ *input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions);
+ return status.ok() ? status.value() != nullptr : false;
+}
+
+absl::Status NLClassifier::SetupRegexTokenizer() {
+ ASSIGN_OR_RETURN(
+ std::unique_ptr<Tokenizer> base_tokenizer,
+ CreateRegexTokenizerFromProcessUnit(
+ GetMetadataExtractor()
+ ->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
+ ->process_units()
+ ->Get(kRegexTokenizerProcessUnitIndex),
+ GetMetadataExtractor()));
+
+ tokenizer_ = std::unique_ptr<RegexTokenizer>(
+ dynamic_cast<RegexTokenizer*>(base_tokenizer.release()));
+
+ return absl::OkStatus();
+}
+
+} // namespace nlclassifier
+} // namespace text
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
new file mode 100644
index 00000000..2a9573a1
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
@@ -0,0 +1,181 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_
+
+#include <stddef.h>
+#include <string.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/string_type.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
+#include "tensorflow_lite_support/cc/task/core/category.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace nlclassifier {
+
+// Options to identify the input and output tensors of the model.
+struct NLClassifierOptions {
+ int input_tensor_index = 0;
+ int output_score_tensor_index = 0;
+ // By default there is no output label tensor. The label file can be attached
+ // to the output score tensor metadata.
+ int output_label_tensor_index = -1;
+ std::string input_tensor_name = "INPUT";
+ std::string output_score_tensor_name = "OUTPUT_SCORE";
+ std::string output_label_tensor_name = "OUTPUT_LABEL";
+};
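+
+// For example, a model whose scores live in an output tensor named
+// "probability" (an illustrative name) could be configured with:
+//   NLClassifierOptions options;
+//   options.output_score_tensor_name = "probability";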
+
+// Classifier API for NLClassification tasks; categorizes a string into
+// different classes.
+//
+// The API expects a TFLite model with the following input/output tensors:
+// Input tensor:
+//   (kTfLiteString) - input of the model, accepts a string.
+//     or
+//   (kTfLiteInt32) - input of the model, accepts the tokenized indices of a
+//   string input. A RegexTokenizer needs to be set up in the input tensor's
+//   metadata.
+// Output score tensor:
+//   (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/
+//   kTfLiteFloat64/kTfLiteBool)
+//   - output scores for each class; if the type is one of the Int types, it
+//     is dequantized to double, and if the type is kTfLiteBool the values are
+//     converted to 0.0 and 1.0 respectively.
+//   - can have an optional associated file in metadata for labels; the file
+//     should be a plain text file with one label per line, and the number of
+//     labels should match the number of categories the model outputs.
+// Output label tensor: optional
+//   (kTfLiteString/kTfLiteInt32)
+//   - output class name for each class; should be of the same length as the
+//     scores. If this tensor is not present, the API uses score indices as
+//     class names.
+//   - will be ignored if the output score tensor already has an associated
+//     label file.
+//
+// By default the API tries to find the input/output tensors with the default
+// configuration in NLClassifierOptions, with tensor name prioritized over
+// tensor index. The options are configurable for different TFLite models.
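+//
+// A minimal usage sketch with the default options (the model path is a
+// placeholder):
+//
+//   auto classifier_or =
+//       NLClassifier::CreateFromFileAndOptions("/path/to/model.tflite");
+//   if (classifier_or.ok()) {
+//     std::vector<core::Category> categories =
+//         classifier_or.value()->Classify("input text");
+//   }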
+class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
+ const std::string&> {
+ public:
+ using BaseTaskApi::BaseTaskApi;
+
+  // Creates an NLClassifier from a TFLite model buffer.
+ static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
+ CreateFromBufferAndOptions(
+ const char* model_buffer_data, size_t model_buffer_size,
+ const NLClassifierOptions& options = {},
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+
+  // Creates an NLClassifier from a TFLite model file.
+ static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
+ CreateFromFileAndOptions(
+ const std::string& path_to_model, const NLClassifierOptions& options = {},
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+
+  // Creates an NLClassifier from a TFLite model file descriptor.
+ static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
+ CreateFromFdAndOptions(
+ int fd, const NLClassifierOptions& options = {},
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+
+ // Performs classification on a string input, returns classified results.
+ std::vector<core::Category> Classify(const std::string& text);
+
+ protected:
+ static constexpr int kOutputTensorIndex = 0;
+ static constexpr int kOutputTensorLabelFileIndex = 0;
+
+ absl::Status Initialize(const NLClassifierOptions& options);
+ const NLClassifierOptions& GetOptions() const;
+
+  // Tries to extract the attached label file from metadata and initialize
+  // labels_vector_; returns an error if the metadata type is incorrect or no
+  // label file is attached.
+ absl::Status TrySetLabelFromMetadata(const TensorMetadata* metadata);
+
+  // Passes the input text through to the model's input tensor.
+ absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
+ const std::string& input) override;
+
+  // Extracts the model output and creates results using the output label
+  // tensor or the label file attached in metadata. If neither is found, uses
+  // the output score indices as labels.
+ tflite::support::StatusOr<std::vector<core::Category>> Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const std::string& input) override;
+
+ std::vector<core::Category> BuildResults(const TfLiteTensor* scores,
+ const TfLiteTensor* labels);
+
+  // Gets the tensor from a vector of tensors by checking the tensor name
+  // first and the tensor index second; returns nullptr if no tensor is found.
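+  // For example, with metadata naming the tensors {"INPUT"} and a tensors
+  // vector of matching size, a lookup with name "INPUT" returns tensors[0] by
+  // name match, while a lookup with an unknown name and index 1 falls back to
+  // tensors[1] (illustrative values).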
+ template <typename TensorType>
+ static TensorType* FindTensorWithNameOrIndex(
+ const std::vector<TensorType*>& tensors,
+ const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
+ metadata_array,
+ const std::string& name, int index) {
+ if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
+ for (int i = 0; i < metadata_array->size(); i++) {
+ if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
+ return tensors[i];
+ }
+ }
+ }
+
+ for (TensorType* tensor : tensors) {
+ if (tensor->name == name) {
+ return tensor;
+ }
+ }
+ return index >= 0 && index < tensors.size() ? tensors[index] : nullptr;
+ }
+
+ private:
+ bool HasRegexTokenizerMetadata();
+ absl::Status SetupRegexTokenizer();
+
+ NLClassifierOptions options_;
+ // labels vector initialized from output tensor's associated file, if one
+ // exists.
+ std::unique_ptr<std::vector<std::string>> labels_vector_;
+ std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_;
+};
+
+} // namespace nlclassifier
+} // namespace text
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc
new file mode 100644
index 00000000..3f7827d8
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc
@@ -0,0 +1,89 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h"
+
+#include <cstring>
+#include <memory>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow_lite_support/cc/task/core/category.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
+
+using CategoryCPP = ::tflite::task::core::Category;
+using NLClassifierCPP = ::tflite::task::text::nlclassifier::NLClassifier;
+using NLClassifierOptionsCPP =
+ ::tflite::task::text::nlclassifier::NLClassifierOptions;
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+struct NLClassifier {
+ std::unique_ptr<NLClassifierCPP> impl;
+};
+
+NLClassifier* NLClassifierFromFileAndOptions(
+ const char* model_path, const NLClassifierOptions* options) {
+ auto classifier_status = NLClassifierCPP::CreateFromFileAndOptions(
+ std::string(model_path),
+ {
+ .input_tensor_index = options->input_tensor_index,
+ .output_score_tensor_index = options->output_score_tensor_index,
+ .output_label_tensor_index = options->output_label_tensor_index,
+ .input_tensor_name = !options->input_tensor_name
+ ? ""
+ : std::string(options->input_tensor_name),
+ .output_score_tensor_name =
+ !options->output_score_tensor_name
+ ? ""
+ : std::string(options->output_score_tensor_name),
+ .output_label_tensor_name =
+ !options->output_label_tensor_name
+ ? ""
+ : std::string(options->output_label_tensor_name),
+ });
+
+ if (classifier_status.ok()) {
+ return new NLClassifier{
+ .impl = std::unique_ptr<NLClassifierCPP>(dynamic_cast<NLClassifierCPP*>(
+ classifier_status.value().release()))};
+ } else {
+ return nullptr;
+ }
+}
+
+Categories* NLClassifierClassify(const NLClassifier* classifier,
+ const char* text) {
+ std::vector<CategoryCPP> results =
+ classifier->impl->Classify(absl::string_view(text).data());
+ size_t size = results.size();
+ auto* categories = new Category[size];
+
+ for (size_t i = 0; i < size; ++i) {
+ categories[i].text = strdup(results[i].class_name.c_str());
+ categories[i].score = results[i].score;
+ }
+
+ auto* c_categories = new Categories;
+ c_categories->size = size;
+ c_categories->categories = categories;
+ return c_categories;
+}
+
+void NLClassifierDelete(NLClassifier* classifier) { delete classifier; }
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h
new file mode 100644
index 00000000..1af93f29
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h
@@ -0,0 +1,72 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_H_
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h"
+// --------------------------------------------------------------------------
+/// C API for NLClassifier.
+///
+/// The API leans towards simplicity and uniformity instead of convenience, as
+/// most usage will be by language-specific wrappers. It provides largely the
+/// same set of functionality as that of the C++ TensorFlow Lite `NLClassifier`
+/// API, but is useful for shared libraries where having a stable ABI boundary
+/// is important.
+///
+/// Usage:
+/// <pre><code>
+/// // Create the classifier, passing options that identify the model's
+/// // input/output tensors (see struct NLClassifierOptions below).
+/// struct NLClassifierOptions options = { ... };
+/// NLClassifier* classifier = NLClassifierFromFileAndOptions(
+///     "/path/to/model.tflite", &options);
+///
+/// // Classify the input text.
+/// Categories* categories = NLClassifierClassify(classifier, text);
+///
+/// // Dispose of the classifier and the classification result.
+/// NLClassifierDelete(classifier);
+/// NLClassifierCategoriesDelete(categories);
+/// </code></pre>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef struct NLClassifier NLClassifier;
+
+struct NLClassifierOptions {
+ int input_tensor_index;
+ int output_score_tensor_index;
+ int output_label_tensor_index;
+ const char* input_tensor_name;
+ const char* output_score_tensor_name;
+ const char* output_label_tensor_name;
+};
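+
+// Example options mirroring the C++ defaults (the values are illustrative,
+// not mandated by the API):
+//   struct NLClassifierOptions options = {
+//       .input_tensor_index = 0,
+//       .output_score_tensor_index = 0,
+//       .output_label_tensor_index = -1,
+//       .input_tensor_name = "INPUT",
+//       .output_score_tensor_name = "OUTPUT_SCORE",
+//       .output_label_tensor_name = "OUTPUT_LABEL",
+//   };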
+
+// Creates an NLClassifier from the model path and options; returns nullptr if
+// the file doesn't exist or is not a well-formed TFLite model.
+extern NLClassifier* NLClassifierFromFileAndOptions(
+ const char* model_path,
+ const struct NLClassifierOptions* options);
+
+// Invokes the encapsulated TFLite model and classifies the input text.
+extern struct Categories* NLClassifierClassify(const NLClassifier* classifier,
+ const char* text);
+
+extern void NLClassifierDelete(NLClassifier* classifier);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_H_
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc
new file mode 100644
index 00000000..3beb658a
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h"
+
+#include <cstdlib>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+void NLClassifierCategoriesDelete(Categories* categories) {
+  // The category labels were allocated with strdup(); free them before
+  // deleting the array and the struct itself to avoid leaking them.
+  for (int i = 0; i < categories->size; ++i) {
+    free(categories->categories[i].text);
+  }
+  delete[] categories->categories;
+  delete categories;
+}
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h
new file mode 100644
index 00000000..663c873c
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h
@@ -0,0 +1,43 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_
+
+// Common structs shared between the NLClassifier C APIs.
+//
+/// Usage:
+/// <pre><code>
+/// // Dispose of the Categories object returned by a classify call.
+/// NLClassifierCategoriesDelete(categories);
+/// </code></pre>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+struct Category {
+ char* text;
+ double score;
+};
+
+struct Categories {
+ int size;
+ struct Category* categories;
+};
+
+extern void NLClassifierCategoriesDelete(struct Categories* categories);
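+
+// A sketch of typical consumption from C (the classify call that produced
+// `categories` is omitted):
+//   for (int i = 0; i < categories->size; ++i) {
+//     printf("%s: %f\n", categories->categories[i].text,
+//            categories->categories[i].score);
+//   }
+//   NLClassifierCategoriesDelete(categories);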
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_
diff --git a/tensorflow_lite_support/cc/task/text/qa/BUILD b/tensorflow_lite_support/cc/task/text/qa/BUILD
new file mode 100644
index 00000000..49ad5a1f
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/qa/BUILD
@@ -0,0 +1,61 @@
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "bert_qa_c_api.h",
+])
+
+cc_library(
+ name = "question_answerer",
+ hdrs = [
+ "question_answerer.h",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc/task/core:base_task_api",
+ "//tensorflow_lite_support/cc/task/core:tflite_engine",
+ ],
+)
+
+cc_library(
+ name = "bert_question_answerer",
+ srcs = [
+ "bert_question_answerer.cc",
+ ],
+ hdrs = [
+ "bert_question_answerer.h",
+ ],
+ deps = [
+ ":question_answerer",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:base_task_api",
+ "//tensorflow_lite_support/cc/task/core:task_api_factory",
+ "//tensorflow_lite_support/cc/task/core:task_utils",
+ "//tensorflow_lite_support/cc/task/core:tflite_engine",
+ "//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer",
+ "//tensorflow_lite_support/cc/text/tokenizers:sentencepiece_tokenizer",
+ "//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
+ "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "bert_qa_c_api",
+ srcs = [
+ "bert_qa_c_api.cc",
+ ],
+ hdrs = [
+ "bert_qa_c_api.h",
+ ],
+ visibility = ["//tensorflow_lite_support:__subpackages__"],
+ deps = [
+ ":bert_question_answerer",
+ ":question_answerer",
+ ],
+)
diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc b/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc
new file mode 100644
index 00000000..5fafb59d
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc
@@ -0,0 +1,79 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h"
+
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
+#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h"
+
+using BertQuestionAnswererCPP = ::tflite::task::text::qa::BertQuestionAnswerer;
+using QaAnswerCPP = ::tflite::task::text::qa::QaAnswer;
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+struct BertQuestionAnswerer {
+ std::unique_ptr<BertQuestionAnswererCPP> impl;
+};
+
+BertQuestionAnswerer* BertQuestionAnswererFromFile(const char* model_path) {
+ auto bert_qa_status =
+ BertQuestionAnswererCPP::CreateFromFile(std::string(model_path));
+ if (bert_qa_status.ok()) {
+ return new BertQuestionAnswerer{
+ .impl = std::unique_ptr<BertQuestionAnswererCPP>(
+ dynamic_cast<BertQuestionAnswererCPP*>(
+ bert_qa_status.value().release()))};
+ } else {
+ return nullptr;
+ }
+}
+
+QaAnswers* BertQuestionAnswererAnswer(
+ const BertQuestionAnswerer* question_answerer, const char* context,
+ const char* question) {
+ std::vector<QaAnswerCPP> answers = question_answerer->impl->Answer(
+ absl::string_view(context).data(), absl::string_view(question).data());
+ size_t size = answers.size();
+ auto* qa_answers = new QaAnswer[size];
+
+ for (size_t i = 0; i < size; ++i) {
+ qa_answers[i].start = answers[i].pos.start;
+ qa_answers[i].end = answers[i].pos.end;
+ qa_answers[i].logit = answers[i].pos.logit;
+ qa_answers[i].text = strdup(answers[i].text.c_str());
+ }
+
+ auto* c_answers = new QaAnswers;
+ c_answers->size = size;
+ c_answers->answers = qa_answers;
+ return c_answers;
+}
+
+void BertQuestionAnswererDelete(BertQuestionAnswerer* bert_question_answerer) {
+ delete bert_question_answerer;
+}
+
+void BertQuestionAnswererQaAnswersDelete(QaAnswers* qa_answers) {
+  // The answer strings were allocated with strdup(); free them before
+  // deleting the array and the struct itself to avoid leaking them.
+  for (int i = 0; i < qa_answers->size; ++i) {
+    free(qa_answers->answers[i].text);
+  }
+  delete[] qa_answers->answers;
+  delete qa_answers;
+}
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h b/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h
new file mode 100644
index 00000000..7fd36948
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QA_C_API_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QA_C_API_H_
+
+// --------------------------------------------------------------------------
+/// C API for BertQuestionAnswerer.
+///
+/// The API leans towards simplicity and uniformity instead of convenience, as
+/// most usage will be by language-specific wrappers. It provides largely the
+/// same set of functionality as that of the C++ TensorFlow Lite
+/// `BertQuestionAnswerer` API, but is useful for shared libraries where having
+/// a stable ABI boundary is important.
+///
+/// Usage:
+/// <pre><code>
+/// // Create the question answerer.
+/// BertQuestionAnswerer* qa_answerer =
+///     BertQuestionAnswererFromFile("/path/to/model.tflite");
+///
+/// // Answer a question based on the given context.
+/// QaAnswers* answers =
+///     BertQuestionAnswererAnswer(qa_answerer, context, question);
+///
+/// // Dispose of the API and QaAnswers objects.
+/// BertQuestionAnswererDelete(qa_answerer);
+/// BertQuestionAnswererQaAnswersDelete(answers);
+/// </code></pre>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef struct BertQuestionAnswerer BertQuestionAnswerer;
+
+struct QaAnswer {
+ int start;
+ int end;
+ float logit;
+ char* text;
+};
+
+struct QaAnswers {
+ int size;
+ struct QaAnswer* answers;
+};
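+
+// A sketch of iterating the returned answers (the call that produced
+// `answers` is shown in the usage block above):
+//   for (int i = 0; i < answers->size; ++i) {
+//     printf("[%d, %d] %s (logit %f)\n", answers->answers[i].start,
+//            answers->answers[i].end, answers->answers[i].text,
+//            answers->answers[i].logit);
+//   }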
+
+// Creates a BertQuestionAnswerer from the model path; returns nullptr if the
+// file doesn't exist or is not a well-formed TFLite model.
+extern BertQuestionAnswerer* BertQuestionAnswererFromFile(
+ const char* model_path);
+
+// Invokes the encapsulated TFLite model and answers a question based on
+// context.
+extern struct QaAnswers* BertQuestionAnswererAnswer(
+ const BertQuestionAnswerer* question_answerer, const char* context,
+ const char* question);
+
+extern void BertQuestionAnswererDelete(
+ BertQuestionAnswerer* bert_question_answerer);
+
+extern void BertQuestionAnswererQaAnswersDelete(struct QaAnswers* qa_answers);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QA_C_API_H_
diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc
new file mode 100644
index 00000000..aa7ffef3
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc
@@ -0,0 +1,393 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
+
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace qa {
+
+constexpr char kIdsTensorName[] = "ids";
+constexpr char kMaskTensorName[] = "mask";
+constexpr char kSegmentIdsTensorName[] = "segment_ids";
+constexpr char kEndLogitsTensorName[] = "end_logits";
+constexpr char kStartLogitsTensorName[] = "start_logits";
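+
+// A BERT QA model is expected to expose three input tensors ("ids", "mask",
+// "segment_ids") and two output tensors ("start_logits", "end_logits");
+// they are located by the names above via FindTensorByName.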
+
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::support::text::tokenizer::BertTokenizer;
+using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
+using ::tflite::support::text::tokenizer::SentencePieceTokenizer;
+using ::tflite::support::text::tokenizer::TokenizerResult;
+using ::tflite::task::core::FindTensorByName;
+using ::tflite::task::core::PopulateTensor;
+using ::tflite::task::core::PopulateVector;
+using ::tflite::task::core::ReverseSortIndices;
+
+namespace {
+constexpr int kTokenizerProcessUnitIndex = 0;
+}  // namespace
+
+StatusOr<std::unique_ptr<QuestionAnswerer>>
+BertQuestionAnswerer::CreateFromFile(
+ const std::string& path_to_model_with_metadata) {
+ std::unique_ptr<BertQuestionAnswerer> api_to_init;
+ ASSIGN_OR_RETURN(
+ api_to_init,
+ core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
+ path_to_model_with_metadata,
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ kNumLiteThreads));
+ RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
+ return api_to_init;
+}
+
+StatusOr<std::unique_ptr<QuestionAnswerer>>
+BertQuestionAnswerer::CreateFromBuffer(
+ const char* model_with_metadata_buffer_data,
+ size_t model_with_metadata_buffer_size) {
+ std::unique_ptr<BertQuestionAnswerer> api_to_init;
+ ASSIGN_OR_RETURN(
+ api_to_init,
+ core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
+ model_with_metadata_buffer_data, model_with_metadata_buffer_size,
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ kNumLiteThreads));
+ RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
+ return api_to_init;
+}
+
+StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd(
+ int fd) {
+ std::unique_ptr<BertQuestionAnswerer> api_to_init;
+ ASSIGN_OR_RETURN(
+ api_to_init,
+ core::TaskAPIFactory::CreateFromFileDescriptor<BertQuestionAnswerer>(
+ fd, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ kNumLiteThreads));
+ RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
+ return api_to_init;
+}
+
+StatusOr<std::unique_ptr<QuestionAnswerer>>
+BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
+ const std::string& path_to_model, const std::string& path_to_vocab) {
+ std::unique_ptr<BertQuestionAnswerer> api_to_init;
+ ASSIGN_OR_RETURN(
+ api_to_init,
+ core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
+ path_to_model,
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ kNumLiteThreads));
+ api_to_init->InitializeBertTokenizer(path_to_vocab);
+ return api_to_init;
+}
+
+StatusOr<std::unique_ptr<QuestionAnswerer>>
+BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
+ const char* model_buffer_data, size_t model_buffer_size,
+ const char* vocab_buffer_data, size_t vocab_buffer_size) {
+ std::unique_ptr<BertQuestionAnswerer> api_to_init;
+ ASSIGN_OR_RETURN(
+ api_to_init,
+ core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
+ model_buffer_data, model_buffer_size,
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ kNumLiteThreads));
+ api_to_init->InitializeBertTokenizerFromBinary(vocab_buffer_data,
+ vocab_buffer_size);
+ return api_to_init;
+}
+
+StatusOr<std::unique_ptr<QuestionAnswerer>>
+BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
+ const std::string& path_to_model, const std::string& path_to_spmodel) {
+ std::unique_ptr<BertQuestionAnswerer> api_to_init;
+ ASSIGN_OR_RETURN(
+ api_to_init,
+ core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
+ path_to_model,
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ kNumLiteThreads));
+ api_to_init->InitializeSentencepieceTokenizer(path_to_spmodel);
+ return api_to_init;
+}
+
+StatusOr<std::unique_ptr<QuestionAnswerer>>
+BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
+ const char* model_buffer_data, size_t model_buffer_size,
+ const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
+ std::unique_ptr<BertQuestionAnswerer> api_to_init;
+ ASSIGN_OR_RETURN(
+ api_to_init,
+ core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
+ model_buffer_data, model_buffer_size,
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
+ kNumLiteThreads));
+ api_to_init->InitializeSentencepieceTokenizerFromBinary(spmodel_buffer_data,
+ spmodel_buffer_size);
+ return api_to_init;
+}
+
+std::vector<QaAnswer> BertQuestionAnswerer::Answer(
+ const std::string& context, const std::string& question) {
+  // The BertQuestionAnswerer implementation for Preprocess() and
+ // Postprocess() never returns errors: just call value().
+ return Infer(context, question).value();
+}
+
+absl::Status BertQuestionAnswerer::Preprocess(
+ const std::vector<TfLiteTensor*>& input_tensors, const std::string& context,
+ const std::string& query) {
+ auto* input_tensor_metadatas =
+ GetMetadataExtractor()->GetInputTensorMetadata();
+ TfLiteTensor* ids_tensor =
+ input_tensor_metadatas
+ ? FindTensorByName(input_tensors, input_tensor_metadatas,
+ kIdsTensorName)
+ : input_tensors[0];
+ TfLiteTensor* mask_tensor =
+ input_tensor_metadatas
+ ? FindTensorByName(input_tensors, input_tensor_metadatas,
+ kMaskTensorName)
+ : input_tensors[1];
+ TfLiteTensor* segment_ids_tensor =
+ input_tensor_metadatas
+ ? FindTensorByName(input_tensors, input_tensor_metadatas,
+ kSegmentIdsTensorName)
+ : input_tensors[2];
+
+ token_to_orig_map_.clear();
+
+  // orig_tokens_ is used to recover the answer string from token indices,
+  // while processed_tokens is lower-cased and used to generate the model
+  // input.
+ orig_tokens_ = absl::StrSplit(context, absl::ByChar(' '), absl::SkipEmpty());
+ std::vector<std::string> processed_tokens(orig_tokens_);
+
+ std::string processed_query = query;
+ if (kUseLowerCase) {
+ for (auto& token : processed_tokens) {
+ absl::AsciiStrToLower(&token);
+ }
+ absl::AsciiStrToLower(&processed_query);
+ }
+
+  TokenizerResult query_tokenize_results =
+      tokenizer_->Tokenize(processed_query);
+
+ std::vector<std::string> query_tokens = query_tokenize_results.subwords;
+ if (query_tokens.size() > kMaxQueryLen) {
+ query_tokens.resize(kMaxQueryLen);
+ }
+
+ // Example:
+ // context: tokenize me please
+ // all_doc_tokens: token ##ize me plea ##se
+ // token_to_orig_index: [0, 0, 1, 2, 2]
+
+ std::vector<std::string> all_doc_tokens;
+ std::vector<int> token_to_orig_index;
+ for (size_t i = 0; i < processed_tokens.size(); i++) {
+ const std::string& token = processed_tokens[i];
+ std::vector<std::string> sub_tokens = tokenizer_->Tokenize(token).subwords;
+ for (const std::string& sub_token : sub_tokens) {
+ token_to_orig_index.emplace_back(i);
+ all_doc_tokens.emplace_back(sub_token);
+ }
+ }
+
+ // -3 accounts for [CLS], [SEP] and [SEP].
+ int max_context_len = kMaxSeqLen - query_tokens.size() - 3;
+ if (all_doc_tokens.size() > max_context_len) {
+ all_doc_tokens.resize(max_context_len);
+ }
+
+ std::vector<std::string> tokens;
+ tokens.reserve(3 + query_tokens.size() + all_doc_tokens.size());
+ std::vector<int> segment_ids;
+ segment_ids.reserve(kMaxSeqLen);
+
+ // Start of generating the features.
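+  // The generated sequence follows the standard BERT input layout:
+  //   [CLS] q_1 ... q_n [SEP] d_1 ... d_m [SEP] 0 ... 0  (padded to kMaxSeqLen)
+  // with segment id 0 for [CLS], the query tokens and the first [SEP],
+  // segment id 1 for the document tokens and the final [SEP], and 0 for the
+  // padding.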
+ tokens.emplace_back("[CLS]");
+ segment_ids.emplace_back(0);
+
+ // For query input.
+ for (const auto& query_token : query_tokens) {
+ tokens.emplace_back(query_token);
+ segment_ids.emplace_back(0);
+ }
+
+  // Separator between the query and the context.
+ tokens.emplace_back("[SEP]");
+ segment_ids.emplace_back(0);
+
+  // For the text (context) input.
+ for (int i = 0; i < all_doc_tokens.size(); i++) {
+ auto& doc_token = all_doc_tokens[i];
+ tokens.emplace_back(doc_token);
+ segment_ids.emplace_back(1);
+ token_to_orig_map_[tokens.size()] = token_to_orig_index[i];
+ }
+
+ // For ending mark.
+ tokens.emplace_back("[SEP]");
+ segment_ids.emplace_back(1);
+
+ std::vector<int> input_ids(tokens.size());
+ input_ids.reserve(kMaxSeqLen);
+ // Convert tokens back into ids
+ for (int i = 0; i < tokens.size(); i++) {
+ auto& token = tokens[i];
+ tokenizer_->LookupId(token, &input_ids[i]);
+ }
+
+ std::vector<int> input_mask;
+ input_mask.reserve(kMaxSeqLen);
+ input_mask.insert(input_mask.end(), tokens.size(), 1);
+
+ int zeros_to_pad = kMaxSeqLen - input_ids.size();
+ input_ids.insert(input_ids.end(), zeros_to_pad, 0);
+ input_mask.insert(input_mask.end(), zeros_to_pad, 0);
+ segment_ids.insert(segment_ids.end(), zeros_to_pad, 0);
+
+ // input_ids INT32[1, 384]
+ PopulateTensor(input_ids, ids_tensor);
+ // input_mask INT32[1, 384]
+ PopulateTensor(input_mask, mask_tensor);
+ // segment_ids INT32[1, 384]
+ PopulateTensor(segment_ids, segment_ids_tensor);
+
+ return absl::OkStatus();
+}
+
+StatusOr<std::vector<QaAnswer>> BertQuestionAnswerer::Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const std::string& /*lowercased_context*/,
+ const std::string& /*lowercased_query*/) {
+ auto* output_tensor_metadatas =
+ GetMetadataExtractor()->GetOutputTensorMetadata();
+
+ const TfLiteTensor* end_logits_tensor =
+ output_tensor_metadatas
+ ? FindTensorByName(output_tensors, output_tensor_metadatas,
+ kEndLogitsTensorName)
+ : output_tensors[0];
+ const TfLiteTensor* start_logits_tensor =
+ output_tensor_metadatas
+ ? FindTensorByName(output_tensors, output_tensor_metadatas,
+ kStartLogitsTensorName)
+ : output_tensors[1];
+
+ std::vector<float> end_logits;
+ std::vector<float> start_logits;
+
+ // end_logits FLOAT[1, 384]
+ PopulateVector(end_logits_tensor, &end_logits);
+ // start_logits FLOAT[1, 384]
+ PopulateVector(start_logits_tensor, &start_logits);
+
+ auto start_indices = ReverseSortIndices(start_logits);
+ auto end_indices = ReverseSortIndices(end_logits);
+
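+  // Consider the kPredictAnsNum best start positions combined with the
+  // kPredictAnsNum best end positions, and keep each (start, end) pair that
+  // maps back to the original text, is well ordered (start <= end) and spans
+  // at most kMaxAnsLen tokens. A pair is scored by the sum of its start and
+  // end logits.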
+ std::vector<QaAnswer::Pos> orig_results;
+ for (int start_index = 0; start_index < kPredictAnsNum; start_index++) {
+ for (int end_index = 0; end_index < kPredictAnsNum; end_index++) {
+ int start = start_indices[start_index];
+ int end = end_indices[end_index];
+
+ if (!token_to_orig_map_.contains(start + kOutputOffset) ||
+ !token_to_orig_map_.contains(end + kOutputOffset) || end < start ||
+ (end - start + 1) > kMaxAnsLen) {
+ continue;
+ }
+ orig_results.emplace_back(
+ QaAnswer::Pos(start, end, start_logits[start] + end_logits[end]));
+ }
+ }
+
+ std::sort(orig_results.begin(), orig_results.end());
+
+ std::vector<QaAnswer> answers;
+ for (int i = 0; i < orig_results.size() && i < kPredictAnsNum; i++) {
+ auto orig_pos = orig_results[i];
+ answers.emplace_back(
+ orig_pos.start > 0 ? ConvertIndexToString(orig_pos.start, orig_pos.end)
+ : "",
+ orig_pos);
+ }
+
+ return answers;
+}
+
+std::string BertQuestionAnswerer::ConvertIndexToString(int start, int end) {
+ int start_index = token_to_orig_map_[start + kOutputOffset];
+ int end_index = token_to_orig_map_[end + kOutputOffset];
+
+ return absl::StrJoin(orig_tokens_.begin() + start_index,
+ orig_tokens_.begin() + end_index + 1, " ");
+}
+
+absl::Status BertQuestionAnswerer::InitializeFromMetadata() {
+ const ProcessUnit* tokenizer_process_unit =
+ GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
+ if (tokenizer_process_unit == nullptr) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "No input process unit found from metadata.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+ ASSIGN_OR_RETURN(tokenizer_,
+ CreateTokenizerFromProcessUnit(tokenizer_process_unit,
+ GetMetadataExtractor()));
+ return absl::OkStatus();
+}
+
+void BertQuestionAnswerer::InitializeBertTokenizer(
+ const std::string& path_to_vocab) {
+ tokenizer_ = absl::make_unique<BertTokenizer>(path_to_vocab);
+}
+
+void BertQuestionAnswerer::InitializeBertTokenizerFromBinary(
+ const char* vocab_buffer_data, size_t vocab_buffer_size) {
+ tokenizer_ =
+ absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size);
+}
+
+void BertQuestionAnswerer::InitializeSentencepieceTokenizer(
+ const std::string& path_to_spmodel) {
+ tokenizer_ = absl::make_unique<SentencePieceTokenizer>(path_to_spmodel);
+}
+
+void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary(
+ const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
+ tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data,
+ spmodel_buffer_size);
+}
+
+} // namespace qa
+} // namespace text
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h
new file mode 100644
index 00000000..4c65dc00
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h
@@ -0,0 +1,170 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
+#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace qa {
+
+// BertQA task API, which performs tokenization for models (BERT, Albert,
+// etc.) in Preprocess() and returns the most likely answers.
+//
+// In particular, the BERT branch of models uses the WordPiece tokenizer,
+// while the Albert branch uses the SentencePiece tokenizer.
+//
+// Factory methods:
+// CreateFromFile(path_to_model_with_metadata)
+// CreateFromBuffer(model_with_metadata_buffer_data,
+// model_with_metadata_buffer_size)
+// CreateFromFd(file_descriptor_to_model_with_metadata)
+// Generic API to create the QuestionAnswerer for bert models with metadata
+// populated. The API expects a Bert based TFLite model with metadata
+// containing the following information:
+//   - input_process_units for Wordpiece/Sentencepiece Tokenizer. A Wordpiece
+//     Tokenizer can be used for a MobileBert[0] model, and a Sentencepiece
+//     Tokenizer can be used for an Albert[1] model
+// - 3 input tensors with names "ids", "mask" and "segment_ids"
+// - 2 output tensors with names "end_logits" and "start_logits"
+// [0]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
+// [1]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
+//
+// CreateBertQuestionAnswererFromFile(path_to_model, path_to_vocab)
+// Creates a BertQuestionAnswerer from TFLite model file and vocab file for
+// WordPiece tokenizer. Used in C++ environment.
+// One suitable model is:
+// https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
+//
+// CreateBertQuestionAnswererFromBuffer(model_buffer_data, model_buffer_size,
+// vocab_buffer_data, vocab_buffer_size)
+// Creates a BertQuestionAnswerer from TFLite model buffer and vocab file
+//   buffer for WordPiece tokenizer. Used in Java (JNI) environment.
+//
+// CreateAlbertQuestionAnswererFromFile(path_to_model, path_to_spmodel)
+// Creates an AlbertQuestionAnswerer from TFLite model file and
+// SentencePiece model file. Used in C++ environment.
+// One suitable model is:
+// https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
+//
+// CreateAlbertQuestionAnswererFromBuffer(model_buffer_data,
+// model_buffer_size,
+// spmodel_buffer_data,
+// spmodel_buffer_size)
+// Creates an AlbertQuestionAnswerer from TFLite model file buffer and
+//   SentencePiece model file buffer. Used in Java (JNI) environment.
+//
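+//
+// Usage sketch (a minimal example; the model path is hypothetical, error
+// handling is omitted for brevity, and `context` and `question` are assumed
+// to be std::string values already defined):
+//
+//   auto answerer_or = BertQuestionAnswerer::CreateFromFile(
+//       "/path/to/model_with_metadata.tflite");
+//   if (answerer_or.ok()) {
+//     std::unique_ptr<QuestionAnswerer> answerer =
+//         std::move(answerer_or).value();
+//     std::vector<QaAnswer> answers = answerer->Answer(context, question);
+//   }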
+
+class BertQuestionAnswerer : public QuestionAnswerer {
+ public:
+ // TODO(b/150904655): add support to parameterize.
+ static constexpr int kMaxQueryLen = 64;
+ static constexpr int kMaxSeqLen = 384;
+ static constexpr int kPredictAnsNum = 5;
+ static constexpr int kMaxAnsLen = 32;
+ // TODO(b/151954803): clarify the offset usage
+ static constexpr int kOutputOffset = 1;
+ static constexpr int kNumLiteThreads = 4;
+ static constexpr bool kUseLowerCase = true;
+
+ static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
+ CreateFromFile(const std::string& path_to_model_with_metadata);
+
+ static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
+ CreateFromBuffer(const char* model_with_metadata_buffer_data,
+ size_t model_with_metadata_buffer_size);
+
+ static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
+ CreateFromFd(int fd);
+
+ static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
+ CreateBertQuestionAnswererFromFile(const std::string& path_to_model,
+ const std::string& path_to_vocab);
+
+ static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
+ CreateBertQuestionAnswererFromBuffer(const char* model_buffer_data,
+ size_t model_buffer_size,
+ const char* vocab_buffer_data,
+ size_t vocab_buffer_size);
+
+ static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
+ CreateAlbertQuestionAnswererFromFile(const std::string& path_to_model,
+ const std::string& path_to_spmodel);
+
+ static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
+ CreateAlbertQuestionAnswererFromBuffer(const char* model_buffer_data,
+ size_t model_buffer_size,
+ const char* spmodel_buffer_data,
+ size_t spmodel_buffer_size);
+
+ explicit BertQuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine)
+ : QuestionAnswerer(std::move(engine)) {}
+
+  // Answers the question based on the context. The result could be empty if
+  // no answer was found in the given context.
+ std::vector<QaAnswer> Answer(const std::string& context,
+ const std::string& question) override;
+
+ private:
+ absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
+ const std::string& lowercased_context,
+ const std::string& lowercased_query) override;
+
+ tflite::support::StatusOr<std::vector<QaAnswer>> Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const std::string& lowercased_context,
+ const std::string& lowercased_query) override;
+
+ // Initialize API with a BertTokenizer from the vocabulary file.
+ void InitializeBertTokenizer(const std::string& path_to_vocab);
+ // Initialize API with a BertTokenizer from the vocabulary buffer.
+ void InitializeBertTokenizerFromBinary(const char* vocab_buffer_data,
+ size_t vocab_buffer_size);
+
+ // Initialize API with a SentencepieceTokenizer from the model file.
+ void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel);
+ // Initialize API with a SentencepieceTokenizer from the model buffer.
+ void InitializeSentencepieceTokenizerFromBinary(
+ const char* spmodel_buffer_data, size_t spmodel_buffer_size);
+
+ // Initialize the API with the tokenizer set in the metadata.
+ absl::Status InitializeFromMetadata();
+
+ std::string ConvertIndexToString(int start, int end);
+
+ std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
+  // Maps the index of an input token to the index of the corresponding
+  // untokenized word in the original input.
+ absl::flat_hash_map<size_t, size_t> token_to_orig_map_;
+ // Original tokens of context.
+ std::vector<std::string> orig_tokens_;
+};
+
+} // namespace qa
+} // namespace text
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_
diff --git a/tensorflow_lite_support/cc/task/text/qa/question_answerer.h b/tensorflow_lite_support/cc/task/text/qa/question_answerer.h
new file mode 100644
index 00000000..f46a40c2
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/text/qa/question_answerer.h
@@ -0,0 +1,65 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace qa {
+
+// Struct representing an answer returned by a QuestionAnswerer.
+struct QaAnswer {
+  // Struct representing the logit and the offset of the answer relative to
+  // the context.
+ struct Pos {
+ Pos(int arg_start, int arg_end, float arg_logit)
+ : start(arg_start), end(arg_end), logit(arg_logit) {}
+ int start, end;
+ float logit;
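+    // Orders by descending logit so that sorting yields best-scored positions
+    // first.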
+ bool operator<(const Pos& rhs) const { return rhs.logit < logit; }
+ };
+
+ QaAnswer(std::string arg_text, Pos arg_pos)
+ : text(std::move(arg_text)), pos(arg_pos) {}
+ std::string text;
+ Pos pos;
+};
+
+// Interface for a Question-Answer API.
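+//
+// A typical consumer iterates over the returned answers, which
+// implementations such as BertQuestionAnswerer provide in best-first order
+// (an illustrative sketch; `answerer` is assumed to be an already created
+// QuestionAnswerer):
+//
+//   std::vector<QaAnswer> answers = answerer->Answer(context, question);
+//   for (const QaAnswer& answer : answers) {
+//     std::cout << answer.text << " (logit=" << answer.pos.logit << ")\n";
+//   }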
+class QuestionAnswerer
+ : public core::BaseTaskApi<std::vector<QaAnswer>, const std::string&,
+ const std::string&> {
+ public:
+ explicit QuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine)
+ : BaseTaskApi(std::move(engine)) {}
+
+ virtual std::vector<QaAnswer> Answer(const std::string& context,
+ const std::string& question) = 0;
+};
+
+} // namespace qa
+} // namespace text
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_
diff --git a/tensorflow_lite_support/cc/task/vision/BUILD b/tensorflow_lite_support/cc/task/vision/BUILD
new file mode 100644
index 00000000..d426486f
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/BUILD
@@ -0,0 +1,108 @@
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "object_detector",
+ srcs = ["object_detector.cc"],
+ hdrs = ["object_detector.h"],
+ deps = [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:external_file_handler",
+ "//tensorflow_lite_support/cc/task/core:task_api_factory",
+ "//tensorflow_lite_support/cc/task/core:task_utils",
+ "//tensorflow_lite_support/cc/task/core:tflite_engine",
+ "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "//tensorflow_lite_support/cc/task/vision/core:label_map_item",
+ "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/core/api",
+ ],
+)
+
+cc_library(
+ name = "image_classifier",
+ srcs = ["image_classifier.cc"],
+ hdrs = ["image_classifier.h"],
+ deps = [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:external_file_handler",
+ "//tensorflow_lite_support/cc/task/core:task_api_factory",
+ "//tensorflow_lite_support/cc/task/core:task_utils",
+ "//tensorflow_lite_support/cc/task/core:tflite_engine",
+ "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api",
+ "//tensorflow_lite_support/cc/task/vision/core:classification_head",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "//tensorflow_lite_support/cc/task/vision/core:label_map_item",
+ "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils",
+ "//tensorflow_lite_support/cc/task/vision/utils:score_calibration",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/core/api",
+ ],
+)
+
+cc_library(
+ name = "image_segmenter",
+ srcs = ["image_segmenter.cc"],
+ hdrs = ["image_segmenter.h"],
+ deps = [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:external_file_handler",
+ "//tensorflow_lite_support/cc/task/core:task_api_factory",
+ "//tensorflow_lite_support/cc/task/core:task_utils",
+ "//tensorflow_lite_support/cc/task/core:tflite_engine",
+ "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "//tensorflow_lite_support/cc/task/vision/core:label_map_item",
+ "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/core/api",
+ ],
+)
diff --git a/tensorflow_lite_support/cc/task/vision/core/BUILD b/tensorflow_lite_support/cc/task/vision/core/BUILD
new file mode 100644
index 00000000..1df86cb9
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/core/BUILD
@@ -0,0 +1,81 @@
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(srcs = ["base_vision_task_api.h"])
+
+cc_library(
+ name = "base_vision_task_api",
+ hdrs = [
+ "base_vision_task_api.h",
+ ],
+ deps = [
+ ":frame_buffer",
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/task/core:base_task_api",
+ "//tensorflow_lite_support/cc/task/core:task_utils",
+ "//tensorflow_lite_support/cc/task/core:tflite_engine",
+ "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils",
+ "//tensorflow_lite_support/cc/task/vision/utils:image_tensor_specs",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/time",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ ],
+)
+
+cc_library(
+ name = "frame_buffer",
+ srcs = ["frame_buffer.cc"],
+ hdrs = ["frame_buffer.h"],
+ deps = [
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:any",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
+ name = "label_map_item",
+ srcs = ["label_map_item.cc"],
+ hdrs = ["label_map_item.h"],
+ deps = [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+cc_library(
+ name = "classification_head",
+ srcs = ["classification_head.cc"],
+ hdrs = ["classification_head.h"],
+ deps = [
+ ":label_map_item",
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/vision/utils:score_calibration",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h b/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
new file mode 100644
index 00000000..feb6b4a1
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h
@@ -0,0 +1,270 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_
+
+#include <array>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/time/clock.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Base class providing common logic for vision models.
+template <class OutputType>
+class BaseVisionTaskApi
+ : public tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
+ const BoundingBox&> {
+ public:
+ explicit BaseVisionTaskApi(std::unique_ptr<core::TfLiteEngine> engine)
+ : tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
+ const BoundingBox&>(std::move(engine)) {
+ }
+ // BaseVisionTaskApi is neither copyable nor movable.
+ BaseVisionTaskApi(const BaseVisionTaskApi&) = delete;
+ BaseVisionTaskApi& operator=(const BaseVisionTaskApi&) = delete;
+
+  // Number of bytes per pixel for the 8-bit-per-channel RGB color space.
+ static constexpr int kRgbPixelBytes = 3;
+
+ // Sets the ProcessEngine used for image pre-processing. Must be called before
+ // any inference is performed. Can be called between inferences to override
+ // the current process engine.
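+  //
+  // For example (an illustrative sketch; `kLibyuv` is assumed to be one of
+  // the engines defined in FrameBufferUtils::ProcessEngine):
+  //   api->SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);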
+ void SetProcessEngine(const FrameBufferUtils::ProcessEngine& process_engine) {
+ frame_buffer_utils_ = FrameBufferUtils::Create(process_engine);
+ }
+
+ protected:
+ using tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&,
+ const BoundingBox&>::engine_;
+
+  // Checks that the input tensor and metadata (if any) are valid, and returns
+  // an error otherwise. This must be called once at initialization time,
+  // before running inference, as it is a prerequisite for `Preprocess`.
+ // Note: the underlying interpreter and metadata extractor are assumed to be
+ // already successfully initialized before calling this method.
+ virtual absl::Status CheckAndSetInputs() {
+ ASSIGN_OR_RETURN(
+ ImageTensorSpecs input_specs,
+ BuildInputImageTensorSpecs(*engine_->interpreter(),
+ *engine_->metadata_extractor()));
+
+ if (input_specs.color_space != tflite::ColorSpaceType_RGB) {
+ return tflite::support::CreateStatusWithPayload(
+ absl::StatusCode::kUnimplemented,
+ "BaseVisionTaskApi only supports RGB color space for now.");
+ }
+
+ input_specs_ = absl::make_unique<ImageTensorSpecs>(input_specs);
+
+ return absl::OkStatus();
+ }
+
+ // Performs image preprocessing on the input frame buffer over the region of
+ // interest so that it fits model requirements (e.g. upright 224x224 RGB) and
+  // populates the corresponding input tensor. This is performed by (in this
+ // order):
+ // - cropping the frame buffer to the region of interest (which, in most
+ // cases, just covers the entire input image),
+ // - resizing it (with bilinear interpolation, aspect-ratio *not* preserved)
+ // to the dimensions of the model input tensor,
+ // - converting it to the colorspace of the input tensor (i.e. RGB, which is
+ // the only supported colorspace for now),
+ // - rotating it according to its `Orientation` so that inference is performed
+ // on an "upright" image.
+ //
+ // IMPORTANT: as a consequence of cropping occurring first, the provided
+ // region of interest is expressed in the unrotated frame of reference
+ // coordinates system, i.e. in `[0, frame_buffer.width) x [0,
+ // frame_buffer.height)`, which are the dimensions of the underlying
+ // `frame_buffer` data before any `Orientation` flag gets applied. Also, the
+ // region of interest is not clamped, so this method will return a non-ok
+ // status if the region is out of these bounds.
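+  //
+  // For example, to process the full frame, the region of interest can be set
+  // to cover the entire buffer (a minimal sketch, assuming standard
+  // proto-generated setters on the BoundingBox proto):
+  //   BoundingBox roi;
+  //   roi.set_origin_x(0);
+  //   roi.set_origin_y(0);
+  //   roi.set_width(frame_buffer.dimension().width);
+  //   roi.set_height(frame_buffer.dimension().height);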
+ absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
+ const FrameBuffer& frame_buffer,
+ const BoundingBox& roi) override {
+ if (input_specs_ == nullptr) {
+ return tflite::support::CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ "Uninitialized input tensor specs: CheckAndSetInputs must be called "
+ "at initialization time.");
+ }
+
+ if (frame_buffer_utils_ == nullptr) {
+ return tflite::support::CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ "Uninitialized frame buffer utils: SetProcessEngine must be called "
+ "at initialization time.");
+ }
+
+ if (input_tensors.size() != 1) {
+ return tflite::support::CreateStatusWithPayload(
+ absl::StatusCode::kInternal, "A single input tensor is expected.");
+ }
+
+ // Input data to be normalized (if needed) and used for inference. In most
+ // cases, this is the result of image preprocessing. In case no image
+ // preprocessing is needed (see below), this points to the input frame
+ // buffer raw data.
+ const uint8* input_data;
+ size_t input_data_byte_size;
+
+ // Optional buffers in case image preprocessing is needed.
+ std::unique_ptr<FrameBuffer> preprocessed_frame_buffer;
+ std::vector<uint8> preprocessed_data;
+
+ if (IsImagePreprocessingNeeded(frame_buffer, roi)) {
+ // Preprocess input image to fit model requirements.
+ // For now RGB is the only color space supported, which is ensured by
+ // `CheckAndSetInputs`.
+ FrameBuffer::Dimension to_buffer_dimension = {input_specs_->image_width,
+ input_specs_->image_height};
+ input_data_byte_size =
+ GetBufferByteSize(to_buffer_dimension, FrameBuffer::Format::kRGB);
+ preprocessed_data.resize(input_data_byte_size / sizeof(uint8), 0);
+ input_data = preprocessed_data.data();
+
+ FrameBuffer::Plane preprocessed_plane = {
+ /*buffer=*/preprocessed_data.data(),
+ /*stride=*/{input_specs_->image_width * kRgbPixelBytes,
+ kRgbPixelBytes}};
+ preprocessed_frame_buffer = FrameBuffer::Create(
+ {preprocessed_plane}, to_buffer_dimension, FrameBuffer::Format::kRGB,
+ FrameBuffer::Orientation::kTopLeft);
+
+ RETURN_IF_ERROR(frame_buffer_utils_->Preprocess(
+ frame_buffer, roi, preprocessed_frame_buffer.get()));
+ } else {
+ // Input frame buffer already targets model requirements: skip image
+ // preprocessing. For RGB, the data is always stored in a single plane.
+ input_data = frame_buffer.plane(0).buffer;
+ input_data_byte_size = frame_buffer.plane(0).stride.row_stride_bytes *
+ frame_buffer.dimension().height;
+ }
+
+ // Then normalize pixel data (if needed) and populate the input tensor.
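+    // For float inputs, each value is normalized per channel c as:
+    //   normalized = (pixel - mean_values[c]) / std_values[c]
+    // e.g. with mean 127.5 and std 127.5 (illustrative values; the actual
+    // ones come from the model metadata), pixel 255 maps to 1.0 and pixel 0
+    // maps to -1.0.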
+ switch (input_specs_->tensor_type) {
+ case kTfLiteUInt8:
+ if (input_tensors[0]->bytes != input_data_byte_size) {
+ return tflite::support::CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ "Size mismatch or unsupported padding bytes between pixel data "
+ "and input tensor.");
+ }
+ // No normalization required: directly populate data.
+ tflite::task::core::PopulateTensor(
+ input_data, input_data_byte_size / sizeof(uint8), input_tensors[0]);
+ break;
+ case kTfLiteFloat32: {
+ if (input_tensors[0]->bytes / sizeof(float) !=
+ input_data_byte_size / sizeof(uint8)) {
+ return tflite::support::CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ "Size mismatch or unsupported padding bytes between pixel data "
+ "and input tensor.");
+ }
+ // Normalize and populate.
+ float* normalized_input_data =
+ tflite::task::core::AssertAndReturnTypedTensor<float>(
+ input_tensors[0]);
+ const tflite::task::vision::NormalizationOptions&
+ normalization_options = input_specs_->normalization_options.value();
+ if (normalization_options.num_values == 1) {
+ float mean_value = normalization_options.mean_values[0];
+ float inv_std_value = (1.0f / normalization_options.std_values[0]);
+ for (int i = 0; i < input_data_byte_size / sizeof(uint8);
+ i++, input_data++, normalized_input_data++) {
+ *normalized_input_data =
+ inv_std_value * (static_cast<float>(*input_data) - mean_value);
+ }
+ } else {
+ std::array<float, 3> inv_std_values = {
+ 1.0f / normalization_options.std_values[0],
+ 1.0f / normalization_options.std_values[1],
+ 1.0f / normalization_options.std_values[2]};
+ for (int i = 0; i < input_data_byte_size / sizeof(uint8);
+ i++, input_data++, normalized_input_data++) {
+ *normalized_input_data = inv_std_values[i % 3] *
+ (static_cast<float>(*input_data) -
+ normalization_options.mean_values[i % 3]);
+ }
+ }
+ break;
+ }
+ case kTfLiteInt8:
+ return tflite::support::CreateStatusWithPayload(
+ absl::StatusCode::kUnimplemented,
+ "kTfLiteInt8 input type is not implemented yet.");
+ default:
+ return tflite::support::CreateStatusWithPayload(
+ absl::StatusCode::kInternal, "Unexpected input tensor type.");
+ }
+
+ return absl::OkStatus();
+ }
+
+ // Utils for input image preprocessing (resizing, colorspace conversion, etc).
+ std::unique_ptr<FrameBufferUtils> frame_buffer_utils_;
+
+ // Parameters related to the input tensor which represents an image.
+ std::unique_ptr<ImageTensorSpecs> input_specs_;
+
+ private:
+  // Returns true if image preprocessing is needed, false if it can be
+  // skipped.
+ bool IsImagePreprocessingNeeded(const FrameBuffer& frame_buffer,
+ const BoundingBox& roi) {
+ // Is crop required?
+ if (roi.origin_x() != 0 || roi.origin_y() != 0 ||
+ roi.width() != frame_buffer.dimension().width ||
+ roi.height() != frame_buffer.dimension().height) {
+ return true;
+ }
+
+ // Are image transformations required?
+ if (frame_buffer.orientation() != FrameBuffer::Orientation::kTopLeft ||
+ frame_buffer.format() != FrameBuffer::Format::kRGB ||
+ frame_buffer.dimension().width != input_specs_->image_width ||
+ frame_buffer.dimension().height != input_specs_->image_height) {
+ return true;
+ }
+
+ return false;
+ }
+};
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_
diff --git a/tensorflow_lite_support/cc/task/vision/core/classification_head.cc b/tensorflow_lite_support/cc/task/vision/core/classification_head.cc
new file mode 100644
index 00000000..962cb34b
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/core/classification_head.cc
@@ -0,0 +1,114 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/cc/task/vision/core/classification_head.h"
+
+#include "absl/status/status.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+using ::absl::StatusCode;
+using ::tflite::metadata::ModelMetadataExtractor;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+
+StatusOr<ClassificationHead> BuildClassificationHead(
+ const tflite::metadata::ModelMetadataExtractor& metadata_extractor,
+ const tflite::TensorMetadata& output_tensor_metadata,
+ absl::string_view display_names_locale) {
+ ClassificationHead head;
+ if (output_tensor_metadata.name() != nullptr) {
+ head.name = output_tensor_metadata.name()->str();
+ }
+
+ // Build label map, if present.
+ const std::string labels_filename =
+ ModelMetadataExtractor::FindFirstAssociatedFileName(
+ output_tensor_metadata,
+ tflite::AssociatedFileType_TENSOR_AXIS_LABELS);
+ if (!labels_filename.empty()) {
+ ASSIGN_OR_RETURN(absl::string_view labels_file,
+ metadata_extractor.GetAssociatedFile(labels_filename));
+ const std::string display_names_filename =
+ ModelMetadataExtractor::FindFirstAssociatedFileName(
+ output_tensor_metadata,
+ tflite::AssociatedFileType_TENSOR_AXIS_LABELS,
+ display_names_locale);
+ absl::string_view display_names_file;
+ if (!display_names_filename.empty()) {
+ ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
+ display_names_filename));
+ }
+ ASSIGN_OR_RETURN(head.label_map_items,
+ BuildLabelMapFromFiles(labels_file, display_names_file));
+ }
+
+ // Set score threshold, if present.
+ ASSIGN_OR_RETURN(const tflite::ProcessUnit* score_thresholding_process_unit,
+ ModelMetadataExtractor::FindFirstProcessUnit(
+ output_tensor_metadata,
+ tflite::ProcessUnitOptions_ScoreThresholdingOptions));
+ if (score_thresholding_process_unit != nullptr) {
+ head.score_threshold =
+ score_thresholding_process_unit->options_as_ScoreThresholdingOptions()
+ ->global_score_threshold();
+ }
+
+ // Build score calibration parameters, if present.
+ ASSIGN_OR_RETURN(const tflite::ProcessUnit* score_calibration_process_unit,
+ ModelMetadataExtractor::FindFirstProcessUnit(
+ output_tensor_metadata,
+ tflite::ProcessUnitOptions_ScoreCalibrationOptions));
+ if (score_calibration_process_unit != nullptr) {
+ if (labels_filename.empty()) {
+ return CreateStatusWithPayload(
+ StatusCode::kNotFound,
+ "Using ScoreCalibrationOptions requires a label map to be provided "
+ "as TENSOR_AXIS_LABELS associated file.",
+ TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
+ }
+ const std::string score_calibration_filename =
+ ModelMetadataExtractor::FindFirstAssociatedFileName(
+ output_tensor_metadata,
+ tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION);
+ if (score_calibration_filename.empty()) {
+ return CreateStatusWithPayload(
+ StatusCode::kNotFound,
+ "Found ScoreCalibrationOptions but missing required associated "
+ "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.",
+ TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
+ }
+ ASSIGN_OR_RETURN(
+ absl::string_view score_calibration_file,
+ metadata_extractor.GetAssociatedFile(score_calibration_filename));
+ ASSIGN_OR_RETURN(SigmoidCalibrationParameters sigmoid_params,
+ BuildSigmoidCalibrationParams(
+ *score_calibration_process_unit
+ ->options_as_ScoreCalibrationOptions(),
+ score_calibration_file, head.label_map_items));
+ head.calibration_params = sigmoid_params;
+ }
+
+ return head;
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/core/classification_head.h b/tensorflow_lite_support/cc/task/vision/core/classification_head.h
new file mode 100644
index 00000000..07cd8b9b
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/core/classification_head.h
@@ -0,0 +1,110 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// A single classifier head for an image classifier model, associated with a
+// corresponding output tensor.
+struct ClassificationHead {
+ ClassificationHead() : score_threshold(0) {}
+
+ explicit ClassificationHead(
+ const std::vector<tflite::task::vision::LabelMapItem>&& label_map_items)
+ : label_map_items(label_map_items), score_threshold(0) {}
+
+  // An optional name that usually indicates what this set of classes
+  // represents, e.g. "flowers".
+ std::string name;
+ // The label map representing the list of supported classes, aka labels.
+ //
+ // This must be in direct correspondence with the associated output tensor,
+ // i.e.:
+ //
+ // - The number of classes must match with the dimension of the corresponding
+ // output tensor,
+ // - The i-th item in the label map is assumed to correspond to the i-th
+ // output value in the output tensor.
+ //
+  // This requires dedicated sanity checks to be put in place before running
+  // inference.
+ std::vector<tflite::task::vision::LabelMapItem> label_map_items;
+  // Recommended score threshold, typically in [0,1). Classification results
+  // with a score below this value are considered low-confidence and should be
+  // rejected from returned results.
+ float score_threshold;
+ // Optional score calibration parameters (one set of parameters per class in
+ // the label map). This is primarily meant for multi-label classifiers made of
+ // independent sigmoids.
+ //
+ // Such parameters are usually tuned so that calibrated scores can be compared
+ // to a default threshold common to all classes to achieve a given amount of
+ // precision.
+ //
+ // Example: 60% precision for threshold = 0.5.
+ absl::optional<tflite::task::vision::SigmoidCalibrationParameters>
+ calibration_params;
+};
+
+// Builds a classification head using the provided metadata extractor, for the
+// given output tensor metadata. Returns an error in case the head cannot be
+// built (e.g. missing associated file for score calibration parameters).
+//
+// Optionally it is possible to specify which locale should be used (e.g. "en")
+// to fill the label map display names, if any, and provided the corresponding
+// associated file is present in the metadata. If no locale is specified, or if
+// there is no associated file for the provided locale, display names are just
+// left empty and no error is returned.
+//
+// E.g. (metadata displayed in JSON format below):
+//
+// ...
+// "associated_files": [
+// {
+// "name": "labels.txt",
+// "type": "TENSOR_AXIS_LABELS"
+// },
+// {
+// "name": "labels-en.txt",
+// "type": "TENSOR_AXIS_LABELS",
+// "locale": "en"
+// },
+// ...
+//
+// See metadata schema TENSOR_AXIS_LABELS for more details.
+tflite::support::StatusOr<ClassificationHead> BuildClassificationHead(
+ const tflite::metadata::ModelMetadataExtractor& metadata_extractor,
+ const tflite::TensorMetadata& output_tensor_metadata,
+ absl::string_view display_names_locale = absl::string_view());
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_
diff --git a/tensorflow_lite_support/cc/task/vision/core/frame_buffer.cc b/tensorflow_lite_support/cc/task/vision/core/frame_buffer.cc
new file mode 100644
index 00000000..02658cd9
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/core/frame_buffer.cc
@@ -0,0 +1,179 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+using ::tflite::support::StatusOr;
+
+namespace {
+
+// Returns whether the input `format` is a supported YUV format.
+bool IsSupportedYuvFormat(FrameBuffer::Format format) {
+ return format == FrameBuffer::Format::kNV21 ||
+ format == FrameBuffer::Format::kNV12 ||
+ format == FrameBuffer::Format::kYV12 ||
+ format == FrameBuffer::Format::kYV21;
+}
+
+// Returns the YuvData representation of a supported single-plane FrameBuffer.
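+// In a single-plane buffer the chroma data directly follows the luma data in
+// the same contiguous allocation, e.g. for NV21 the layout is
+// [ Y plane | interleaved V/U samples ]; this is why the U/V pointers below
+// are derived from y_buffer + y_buffer_size.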
+StatusOr<FrameBuffer::YuvData> GetYuvDataFromOnePlaneFrameBuffer(
+ const FrameBuffer& source) {
+ if (!IsSupportedYuvFormat(source.format())) {
+ return absl::InvalidArgumentError(
+ "The source FrameBuffer format is not part of YUV420 family.");
+ }
+
+ FrameBuffer::YuvData result;
+ const int y_buffer_size =
+ source.plane(0).stride.row_stride_bytes * source.dimension().height;
+ const int uv_buffer_size =
+ ((source.plane(0).stride.row_stride_bytes + 1) / 2) *
+ ((source.dimension().height + 1) / 2);
+ result.y_buffer = source.plane(0).buffer;
+ result.y_row_stride = source.plane(0).stride.row_stride_bytes;
+ result.uv_row_stride = result.y_row_stride;
+
+ if (source.format() == FrameBuffer::Format::kNV21) {
+ result.v_buffer = result.y_buffer + y_buffer_size;
+ result.u_buffer = result.v_buffer + 1;
+ result.uv_pixel_stride = 2;
+    // If y_row_stride equals the frame width and is an odd value,
+    // uv_row_stride = y_row_stride + 1; otherwise uv_row_stride = y_row_stride.
+ if (result.y_row_stride == source.dimension().width &&
+ result.y_row_stride % 2 == 1) {
+ result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2;
+ }
+ } else if (source.format() == FrameBuffer::Format::kNV12) {
+ result.u_buffer = result.y_buffer + y_buffer_size;
+ result.v_buffer = result.u_buffer + 1;
+ result.uv_pixel_stride = 2;
+    // If y_row_stride equals the frame width and is an odd value,
+    // uv_row_stride = y_row_stride + 1; otherwise uv_row_stride = y_row_stride.
+ if (result.y_row_stride == source.dimension().width &&
+ result.y_row_stride % 2 == 1) {
+ result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2;
+ }
+ } else if (source.format() == FrameBuffer::Format::kYV21) {
+ result.u_buffer = result.y_buffer + y_buffer_size;
+ result.v_buffer = result.u_buffer + uv_buffer_size;
+ result.uv_pixel_stride = 1;
+ result.uv_row_stride = (result.y_row_stride + 1) / 2;
+ } else if (source.format() == FrameBuffer::Format::kYV12) {
+ result.v_buffer = result.y_buffer + y_buffer_size;
+ result.u_buffer = result.v_buffer + uv_buffer_size;
+ result.uv_pixel_stride = 1;
+ result.uv_row_stride = (result.y_row_stride + 1) / 2;
+ }
+ return result;
+}
+
+// Returns the YuvData representation of a supported two-plane FrameBuffer.
+StatusOr<FrameBuffer::YuvData> GetYuvDataFromTwoPlaneFrameBuffer(
+ const FrameBuffer& source) {
+ if (source.format() != FrameBuffer::Format::kNV12 &&
+ source.format() != FrameBuffer::Format::kNV21) {
+ return absl::InvalidArgumentError("Unsupported YUV planar format.");
+ }
+
+ FrameBuffer::YuvData result;
+ // Y plane
+ result.y_buffer = source.plane(0).buffer;
+ // All plane strides
+ result.y_row_stride = source.plane(0).stride.row_stride_bytes;
+ result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
+ result.uv_pixel_stride = 2;
+
+ if (source.format() == FrameBuffer::Format::kNV12) {
+ // Y and UV interleaved format
+ result.u_buffer = source.plane(1).buffer;
+ result.v_buffer = result.u_buffer + 1;
+ } else {
+ // Y and VU interleaved format
+ result.v_buffer = source.plane(1).buffer;
+ result.u_buffer = result.v_buffer + 1;
+ }
+ return result;
+}
+
+// Returns the YuvData representation of a supported 3-plane FrameBuffer. Note
+// that NV21 and NV12 are included in the supported YUV formats. Technically,
+// NV21 and NV12 should not be described by the 3-plane format. Historically,
+// NV21 has been used loosely, such that it can also describe the YV21 format.
+// For backwards compatibility, FrameBuffer supports NV21/NV12 with the
+// 3-plane format, but such usage is discouraged.
+StatusOr<FrameBuffer::YuvData> GetYuvDataFromThreePlaneFrameBuffer(
+ const FrameBuffer& source) {
+ if (!IsSupportedYuvFormat(source.format())) {
+ return absl::InvalidArgumentError(
+ "The source FrameBuffer format is not part of YUV420 family.");
+ }
+
+ if (source.plane(1).stride.row_stride_bytes !=
+ source.plane(2).stride.row_stride_bytes ||
+ source.plane(1).stride.pixel_stride_bytes !=
+ source.plane(2).stride.pixel_stride_bytes) {
+ return absl::InternalError("Unsupported YUV planar format.");
+ }
+ FrameBuffer::YuvData result;
+ if (source.format() == FrameBuffer::Format::kNV21 ||
+ source.format() == FrameBuffer::Format::kYV12) {
+    // Y plane followed by V and U. The VU chroma planes can be interleaved
+    // or planar.
+ result.y_buffer = source.plane(0).buffer;
+ result.v_buffer = source.plane(1).buffer;
+ result.u_buffer = source.plane(2).buffer;
+ result.y_row_stride = source.plane(0).stride.row_stride_bytes;
+ result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
+ result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
+ } else {
+    // Y plane followed by U and V. The UV chroma planes can be interleaved
+    // or planar.
+ result.y_buffer = source.plane(0).buffer;
+ result.u_buffer = source.plane(1).buffer;
+ result.v_buffer = source.plane(2).buffer;
+ result.y_row_stride = source.plane(0).stride.row_stride_bytes;
+ result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
+ result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
+ }
+ return result;
+}
+
+} // namespace
+
+StatusOr<FrameBuffer::YuvData> FrameBuffer::GetYuvDataFromFrameBuffer(
+ const FrameBuffer& source) {
+ if (!IsSupportedYuvFormat(source.format())) {
+ return absl::InvalidArgumentError(
+ "The source FrameBuffer format is not part of YUV420 family.");
+ }
+
+ if (source.plane_count() == 1) {
+ return GetYuvDataFromOnePlaneFrameBuffer(source);
+ } else if (source.plane_count() == 2) {
+ return GetYuvDataFromTwoPlaneFrameBuffer(source);
+ } else if (source.plane_count() == 3) {
+ return GetYuvDataFromThreePlaneFrameBuffer(source);
+ }
+  return absl::InvalidArgumentError(
+      "The source FrameBuffer must consist of 1, 2, or 3 planes.");
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
new file mode 100644
index 00000000..31589f38
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h
@@ -0,0 +1,296 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_FRAME_BUFFER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_FRAME_BUFFER_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "absl/types/any.h"
+#include "absl/types/optional.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// A `FrameBuffer` provides a view into the provided backing buffer (e.g. a
+// camera frame or still image) along with buffer format information.
+// FrameBuffer doesn't take ownership of the provided backing buffer: the
+// caller is responsible for managing the backing buffer lifecycle for the
+// lifetime of the FrameBuffer.
+//
+// FrameBuffer also provides a tagging system that allows clients to attach
+// arbitrary tags to an instance. The tagging system is meant for small sets of
+// metadata; FrameBuffer does not use the tags in any way. The uniqueness of a
+// tag is only guarded by the uniqueness of its key.
+// Tags are useful when the uniqueness of a FrameBuffer cannot be determined
+// from its associated metadata. For example, given two FrameBuffer instances
+// with the same metadata (dimension, orientation, format, etc.), one generated
+// by cropping Frame A and the other by resizing Frame A, the client can tag
+// one of them to tell the two apart.
+//
+// Examples:
+//
+// // Create a metadata-only instance with no backing buffer.
+// auto buffer = FrameBuffer::Create(/*planes=*/{}, dimension, kRGBA,
+//                                   kTopLeft);
+//
+// // Create an RGBA instance with a backing buffer on a single plane.
+// FrameBuffer::Plane plane = {rgba_buffer,
+//                             /*stride=*/{dimension.width * 4, 4}};
+// auto buffer = FrameBuffer::Create({plane}, dimension, kRGBA, kTopLeft);
+//
+// // Create a YUV instance with a planar backing buffer. Note that NV21
+// // stores the chroma plane as interleaved VU.
+// FrameBuffer::Plane y_plane = {y_buffer, /*stride=*/{dimension.width, 1}};
+// FrameBuffer::Plane vu_plane = {vu_buffer, /*stride=*/{dimension.width, 2}};
+// auto buffer = FrameBuffer::Create({y_plane, vu_plane}, dimension, kNV21,
+//                                   kLeftTop);
+//
+// // Add / retrieve tags from a FrameBuffer instance.
+// buffer->InsertTag("my_special_key", 1);
+// buffer->GetTag("my_special_key");
+//
+class FrameBuffer {
+ public:
+ // Colorspace formats.
+ enum class Format { kRGBA, kRGB, kNV12, kNV21, kYV12, kYV21, kGRAY };
+
+ // Stride information.
+ struct Stride {
+ // The row stride in bytes. This is the distance between the start pixels of
+ // two consecutive rows in the image.
+ int row_stride_bytes;
+ // This is the distance between two consecutive pixel values in a row of
+ // pixels in bytes. It may be larger than the size of a single pixel to
+ // account for interleaved image data or padded formats.
+ int pixel_stride_bytes;
+ };
+
+ // YUV data structure.
+ struct YuvData {
+ const uint8* y_buffer;
+ const uint8* u_buffer;
+ const uint8* v_buffer;
+ // Y buffer row stride in bytes.
+ int y_row_stride;
+ // U/V buffer row stride in bytes.
+ int uv_row_stride;
+ // U/V pixel stride in bytes. This is the distance between two consecutive
+ // u/v pixel values in a row.
+ int uv_pixel_stride;
+ };
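+
+  // Illustrative sketch: with the 4:2:0 chroma subsampling used by all of the
+  // supported YUV formats, the chroma samples for the full-resolution pixel
+  // (x, y) can be addressed from a `yuv_data` instance as:
+  //
+  //   const uint8* u = yuv_data.u_buffer + (y / 2) * yuv_data.uv_row_stride +
+  //                    (x / 2) * yuv_data.uv_pixel_stride;
+  //   const uint8* v = yuv_data.v_buffer + (y / 2) * yuv_data.uv_row_stride +
+  //                    (x / 2) * yuv_data.uv_pixel_stride;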
+
+ // FrameBuffer content orientation follows EXIF specification. The name of
+ // each enum value defines the position of the 0th row and the 0th column of
+ // the image content. See http://jpegclub.org/exif_orientation.html for
+ // details.
+ enum class Orientation {
+ kTopLeft = 1,
+ kTopRight = 2,
+ kBottomRight = 3,
+ kBottomLeft = 4,
+ kLeftTop = 5,
+ kRightTop = 6,
+ kRightBottom = 7,
+ kLeftBottom = 8
+ };
+
+ // Plane encapsulates buffer and stride information.
+ struct Plane {
+ const uint8* buffer;
+ Stride stride;
+ };
+
+ // Dimension information for the whole frame or a cropped portion of it.
+ struct Dimension {
+    // The width dimension in pixels.
+    int width;
+    // The height dimension in pixels.
+    int height;
+
+ bool operator==(const Dimension& other) const {
+ return width == other.width && height == other.height;
+ }
+
+ bool operator!=(const Dimension& other) const {
+ return width != other.width || height != other.height;
+ }
+
+ bool operator>=(const Dimension& other) const {
+ return width >= other.width && height >= other.height;
+ }
+
+ bool operator<=(const Dimension& other) const {
+ return width <= other.width && height <= other.height;
+ }
+
+ // Swaps width and height.
+ void Swap() {
+ using std::swap;
+ swap(width, height);
+ }
+
+    // Returns the area in pixels, i.e. width * height.
+ int Size() const { return width * height; }
+ };
+
+  // Factory method for creating a FrameBuffer object from row-major backing
+  // buffers. In a streaming use case (e.g. a continuous camera stream), the
+  // timestamp can be used as an ID to identify a frame.
+ static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
+ Dimension dimension, Format format,
+ Orientation orientation,
+ absl::Time timestamp) {
+ return absl::make_unique<FrameBuffer>(planes, dimension, format,
+ orientation, timestamp);
+ }
+
+  // Factory method for creating a FrameBuffer object from movable row-major
+  // backing buffers. In a streaming use case (e.g. a continuous camera
+  // stream), the timestamp can be used as an ID to identify a frame.
+ static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
+ Dimension dimension, Format format,
+ Orientation orientation,
+ absl::Time timestamp) {
+ return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
+ orientation, timestamp);
+ }
+
+  // Factory method for creating a FrameBuffer object from row-major backing
+  // buffers. This overload sets the timestamp to now, and is more suitable
+  // for processing use cases that do not need to re-identify this buffer.
+ static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
+ Dimension dimension, Format format,
+ Orientation orientation) {
+ return absl::make_unique<FrameBuffer>(planes, dimension, format,
+ orientation, absl::Now());
+ }
+
+  // Factory method for creating a FrameBuffer object from movable row-major
+  // backing buffers. This overload sets the timestamp to now, and is more
+  // suitable for processing use cases that do not need to re-identify this
+  // buffer.
+ static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
+ Dimension dimension, Format format,
+ Orientation orientation) {
+ return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
+ orientation, absl::Now());
+ }
+
+  // Returns a YuvData holding the Y, U and V buffers and their stride info,
+  // extracted from the input `source` FrameBuffer, which must use one of the
+  // YUV family formats (i.e. NV12, NV21, YV12 or YV21).
+ static tflite::support::StatusOr<YuvData> GetYuvDataFromFrameBuffer(
+ const FrameBuffer& source);
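+  //
+  // Illustrative usage sketch (assuming a YUV `buffer` built as above):
+  //
+  //   ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data,
+  //                    FrameBuffer::GetYuvDataFromFrameBuffer(*buffer));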
+
+ // Builds a FrameBuffer object from a row-major backing buffer.
+ //
+ // The FrameBuffer does not take ownership of the backing buffer. The backing
+ // buffer is read-only and the caller is responsible for maintaining the
+ // backing buffer lifecycle for the lifetime of FrameBuffer.
+ FrameBuffer(const std::vector<Plane>& planes, Dimension dimension,
+ Format format, Orientation orientation, absl::Time timestamp)
+ : planes_(planes),
+ dimension_(dimension),
+ format_(format),
+ orientation_(orientation),
+ timestamp_(timestamp) {}
+
+ // Builds a FrameBuffer object from a movable row-major backing buffer.
+ //
+ // The FrameBuffer does not take ownership of the backing buffer. The backing
+ // buffer is read-only and the caller is responsible for maintaining the
+ // backing buffer lifecycle for the lifetime of FrameBuffer.
+ FrameBuffer(std::vector<Plane>&& planes, Dimension dimension, Format format,
+ Orientation orientation, absl::Time timestamp)
+ : planes_(std::move(planes)),
+ dimension_(dimension),
+ format_(format),
+ orientation_(orientation),
+ timestamp_(timestamp) {}
+
+  // Returns the number of planes.
+  int plane_count() const { return static_cast<int>(planes_.size()); }
+
+  // Returns the plane at the given `index`, or an empty Plane if `index` is
+  // out of range.
+  Plane plane(int index) const {
+    if (index > -1 && index < static_cast<int>(planes_.size())) {
+      return planes_[index];
+    }
+    return {};
+  }
+
+  // Returns the tag associated with `tag_key`, or an empty absl::any if no
+  // such tag exists.
+ absl::any GetTag(const std::string& tag_key) const {
+ auto iter = tags_.find(tag_key);
+ if (iter != tags_.end()) {
+ return iter->second;
+ }
+ return absl::any();
+ }
+
+ // Inserts or updates the tags map with key value pair (tag_key, tag_value).
+ void InsertOrUpdateTag(const std::string& tag_key, absl::any tag_value) {
+ tags_[tag_key] = std::move(tag_value);
+ }
+
+  // Inserts the key value pair (tag_key, tag_value) into the tags map. If
+  // tag_key already exists, an internal error is returned and the map is left
+  // unchanged.
+ absl::Status InsertTag(const std::string& tag_key, absl::any tag_value) {
+ auto iter = tags_.emplace(tag_key, tag_value);
+ if (iter.second) {
+ return absl::OkStatus();
+ }
+    return absl::InternalError(absl::StrCat(
+        "tag_key already exists in tags. tag_key was not inserted: ", tag_key));
+ }
+
+  // Returns the FrameBuffer dimension.
+  Dimension dimension() const { return dimension_; }
+
+  // Returns the FrameBuffer format.
+  Format format() const { return format_; }
+
+  // Returns the FrameBuffer orientation.
+  Orientation orientation() const { return orientation_; }
+
+  // Returns the FrameBuffer timestamp.
+  absl::Time timestamp() const { return timestamp_; }
+
+ private:
+ std::vector<Plane> planes_;
+ std::map<std::string, absl::any> tags_;
+ Dimension dimension_;
+ Format format_;
+ Orientation orientation_;
+ absl::Time timestamp_;
+};
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_FRAME_BUFFER_H_
diff --git a/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc b/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
new file mode 100644
index 00000000..75b1fc60
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
@@ -0,0 +1,128 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow_lite_support/cc/common.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+using ::absl::StatusCode;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+
+StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
+ absl::string_view labels_file, absl::string_view display_names_file) {
+ if (labels_file.empty()) {
+ return CreateStatusWithPayload(StatusCode::kInvalidArgument,
+ "Expected non-empty labels file.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ std::vector<absl::string_view> labels = absl::StrSplit(labels_file, '\n');
+  // In most cases, there is an empty line (i.e. a trailing newline character)
+  // at the end of the file that needs to be ignored. In such a situation,
+  // StrSplit() will produce a vector with an empty string as its final
+  // element. Also note that in case `labels_file` is entirely empty,
+  // StrSplit() will produce a vector with a single empty substring, so
+  // there's no out-of-range risk here.
+ if (labels[labels.size() - 1].empty()) {
+ labels.pop_back();
+ }
+
+ std::vector<LabelMapItem> label_map_items;
+ label_map_items.reserve(labels.size());
+ for (int i = 0; i < labels.size(); ++i) {
+ label_map_items.emplace_back(LabelMapItem{.name = std::string(labels[i])});
+ }
+
+ if (!display_names_file.empty()) {
+ std::vector<std::string> display_names =
+ absl::StrSplit(display_names_file, '\n');
+ // In most cases, there is an empty line (i.e. newline character) at the end
+ // of the file that needs to be ignored. See above.
+ if (display_names[display_names.size() - 1].empty()) {
+ display_names.pop_back();
+ }
+ if (display_names.size() != labels.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Mismatch between number of labels (%d) and display names (%d).",
+ labels.size(), display_names.size()),
+ TfLiteSupportStatus::kMetadataNumLabelsMismatchError);
+ }
+ for (int i = 0; i < display_names.size(); ++i) {
+ label_map_items[i].display_name = display_names[i];
+ }
+ }
+ return label_map_items;
+}
+
+absl::Status LabelHierarchy::InitializeFromLabelMap(
+ std::vector<LabelMapItem> label_map_items) {
+ parents_map_.clear();
+ for (const LabelMapItem& label : label_map_items) {
+ for (const std::string& child_name : label.child_name) {
+ parents_map_[child_name].insert(label.name);
+ }
+ }
+ if (parents_map_.empty()) {
+ return CreateStatusWithPayload(StatusCode::kInvalidArgument,
+ "Input labelmap is not hierarchical: there "
+ "is no parent-child relationship.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ return absl::OkStatus();
+}
+
+bool LabelHierarchy::HaveAncestorDescendantRelationship(
+ const std::string& ancestor_name,
+ const std::string& descendant_name) const {
+ absl::flat_hash_set<std::string> ancestors;
+ GetAncestors(descendant_name, &ancestors);
+ return ancestors.contains(ancestor_name);
+}
+
+absl::flat_hash_set<std::string> LabelHierarchy::GetParents(
+ const std::string& name) const {
+ absl::flat_hash_set<std::string> parents;
+ auto it = parents_map_.find(name);
+ if (it != parents_map_.end()) {
+ for (const std::string& parent_name : it->second) {
+ parents.insert(parent_name);
+ }
+ }
+ return parents;
+}
+
+void LabelHierarchy::GetAncestors(
+ const std::string& name,
+ absl::flat_hash_set<std::string>* ancestors) const {
+ const absl::flat_hash_set<std::string> parents = GetParents(name);
+ for (const std::string& parent_name : parents) {
+ auto it = ancestors->insert(parent_name);
+ if (it.second) {
+ GetAncestors(parent_name, ancestors);
+ }
+ }
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/core/label_map_item.h b/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
new file mode 100644
index 00000000..3ac9a000
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/core/label_map_item.h
@@ -0,0 +1,95 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Structure mapping a numerical class index output to a Knowledge Graph entity
+// ID or any other string label representing this class. Optionally it is
+// possible to specify an additional display name (in a given language) which is
+// typically used for display purposes.
+struct LabelMapItem {
+ // E.g. name = "/m/02xwb"
+ std::string name;
+ // E.g. display_name = "Fruit"
+ std::string display_name;
+ // Optional list of children (e.g. subcategories) used to represent a
+ // hierarchy.
+ std::vector<std::string> child_name;
+};
+
+// Builds a label map from labels and (optional) display names file contents,
+// both expected to contain one label per line. Those are typically obtained
+// from TFLite Model Metadata TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS
+// associated files.
+// Returns an error e.g. if there's a mismatch between the number of labels and
+// display names.
+tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
+ absl::string_view labels_file, absl::string_view display_names_file);
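+//
+// Illustrative example with in-memory file contents (the label names and
+// display names below are placeholders):
+//
+//   ASSIGN_OR_RETURN(
+//       std::vector<LabelMapItem> items,
+//       BuildLabelMapFromFiles("/m/02xwb\n/m/0dj6p\n", "Fruit\nBread\n"));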
+
+// A class that represents a hierarchy of labels as specified in a label map.
+//
+// For example, it is useful to determine whether one label is a descendant of
+// another label. This can be used to implement label pruning based on
+// hierarchy: e.g. if both "fruit" and "banana" have been inferred by a given
+// classifier model, prune "fruit" from the final results, as "banana" is a
+// more fine-grained descendant.
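+//
+// A minimal usage sketch (illustrative; assumes `items` was built with
+// BuildLabelMapFromFiles and contains parent-child relations):
+//
+//   LabelHierarchy hierarchy;
+//   RETURN_IF_ERROR(hierarchy.InitializeFromLabelMap(items));
+//   if (hierarchy.HaveAncestorDescendantRelationship("fruit", "banana")) {
+//     // "banana" is more fine-grained: prune "fruit" from the results.
+//   }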
+class LabelHierarchy {
+ public:
+ LabelHierarchy() = default;
+
+ // Initializes the hierarchy of labels from a given label map vector. Returns
+ // an error status in case of failure, typically if the input label map does
+ // not contain any hierarchical relations between labels.
+ absl::Status InitializeFromLabelMap(
+ std::vector<LabelMapItem> label_map_items);
+
+ // Returns true if `descendant_name` is a descendant of `ancestor_name` in the
+ // hierarchy of labels. Invalid names, i.e. names which do not exist in the
+ // label map used at initialization time, are ignored.
+ bool HaveAncestorDescendantRelationship(
+ const std::string& ancestor_name,
+ const std::string& descendant_name) const;
+
+ private:
+  // Returns the set of direct parent names, if any, for the input label name.
+ absl::flat_hash_set<std::string> GetParents(const std::string& name) const;
+
+  // Retrieves all ancestor names, if any, for the input label name.
+ void GetAncestors(const std::string& name,
+ absl::flat_hash_set<std::string>* ancestors) const;
+
+ // Label name (key) to parent names (value) direct mapping.
+ absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>>
+ parents_map_;
+};
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_
diff --git a/tensorflow_lite_support/cc/task/vision/image_classifier.cc b/tensorflow_lite_support/cc/task/vision/image_classifier.cc
new file mode 100644
index 00000000..378797b4
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/image_classifier.cc
@@ -0,0 +1,572 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+
+using ::absl::StatusCode;
+using ::tflite::metadata::ModelMetadataExtractor;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::task::core::AssertAndReturnTypedTensor;
+using ::tflite::task::core::TaskAPIFactory;
+using ::tflite::task::core::TfLiteEngine;
+
+// Default score value used as a fallback for classes that (1) have no score
+// calibration data or (2) have a very low-confidence uncalibrated score, i.e.
+// lower than the `min_uncalibrated_score` threshold.
+//
+// (1) This happens when the ScoreCalibration does not cover all the classes
+// listed in the label map. This can be used to enforce the blacklisting of
+// given classes so that they are never returned.
+//
+// (2) This is an optional threshold provided as part of the calibration data.
+// It is used to mitigate false alarms on some classes.
+//
+// In both cases, a class that gets assigned a score of -1 is never returned as
+// it gets discarded by the `score_threshold` check (see post-processing logic).
+constexpr float kDefaultCalibratedScore = -1.0f;
+
+// Calibrated scores should be in the [0, 1] range, otherwise an error is
+// returned at post-processing time.
+constexpr float kMinCalibratedScore = 0.0f;
+constexpr float kMaxCalibratedScore = 1.0f;
+
+} // namespace
+
+/* static */
+StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::CreateFromOptions(
+ const ImageClassifierOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver) {
+ RETURN_IF_ERROR(SanityCheckOptions(options));
+
+ // Copy options to ensure the ExternalFile outlives the constructed object.
+ auto options_copy = absl::make_unique<ImageClassifierOptions>(options);
+
+ ASSIGN_OR_RETURN(auto image_classifier,
+ TaskAPIFactory::CreateFromExternalFileProto<ImageClassifier>(
+ &options_copy->model_file_with_metadata(),
+ std::move(resolver), options_copy->num_threads()));
+
+ RETURN_IF_ERROR(image_classifier->Init(std::move(options_copy)));
+
+ return image_classifier;
+}
+
+/* static */
+absl::Status ImageClassifier::SanityCheckOptions(
+ const ImageClassifierOptions& options) {
+ if (!options.has_model_file_with_metadata()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Missing mandatory `model_file_with_metadata` field",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.max_results() == 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Invalid `max_results` option: value must be != 0",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.score_threshold() < 0 || options.score_threshold() >= 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "`score_threshold` out of range: %f. Valid range is [0,1[.",
+ options.score_threshold()),
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.class_name_whitelist_size() > 0 &&
+ options.class_name_blacklist_size() > 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "`class_name_whitelist` and `class_name_blacklist` are mutually "
+ "exclusive options.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.num_threads() == 0 || options.num_threads() < -1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "`num_threads` must be greater than 0 or equal to -1.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ImageClassifier::Init(
+ std::unique_ptr<ImageClassifierOptions> options) {
+ // Set options.
+ options_ = std::move(options);
+
+  // Perform pre-initialization actions (by default, this sets the image
+  // pre-processing engine to kLibyuv).
+ RETURN_IF_ERROR(PreInit());
+
+ // Sanity check and set inputs and outputs.
+ RETURN_IF_ERROR(CheckAndSetInputs());
+ RETURN_IF_ERROR(CheckAndSetOutputs());
+
+ // Initialize class whitelisting/blacklisting, if any.
+ RETURN_IF_ERROR(CheckAndSetClassNameSet());
+
+ // Perform final initialization (by default, initialize score calibration
+ // parameters, if any).
+ RETURN_IF_ERROR(PostInit());
+
+ return absl::OkStatus();
+}
+
+absl::Status ImageClassifier::PreInit() {
+ SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
+ return absl::OkStatus();
+}
+
+absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); }
+
+absl::Status ImageClassifier::CheckAndSetOutputs() {
+ num_outputs_ = TfLiteEngine::OutputCount(engine_->interpreter());
+
+ // Perform sanity checks and extract metadata.
+ const ModelMetadataExtractor* metadata_extractor =
+ engine_->metadata_extractor();
+
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
+ output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();
+
+ // Loop over output tensors metadata, if any.
+ // Note: models with no output tensor metadata at all are supported.
+ if (output_tensor_metadata != nullptr) {
+ int num_output_tensors = output_tensor_metadata->size();
+
+ if (num_outputs_ != num_output_tensors) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Mismatch between number of output tensors (%d) and "
+ "output tensors "
+ "metadata (%d).",
+ num_outputs_, num_output_tensors),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+
+ for (int i = 0; i < num_output_tensors; ++i) {
+ const tflite::TensorMetadata* output_tensor =
+ output_tensor_metadata->Get(i);
+
+ ASSIGN_OR_RETURN(
+ ClassificationHead head,
+ BuildClassificationHead(*metadata_extractor, *output_tensor,
+ options_->display_names_locale()));
+
+ classification_heads_.emplace_back(std::move(head));
+ }
+ }
+
+ // If classifier heads are not set, build default ones based on model
+ // introspection. This happens if a model with partial or no metadata was
+ // provided through the `model_file_with_metadata` options field.
+ if (classification_heads_.empty()) {
+ classification_heads_.reserve(num_outputs_);
+ for (int output_index = 0; output_index < num_outputs_; ++output_index) {
+ classification_heads_.emplace_back(ClassificationHead{});
+ }
+ }
+
+ if (num_outputs_ != classification_heads_.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Got %d classifier head(s), expected %d according to "
+ "the label map.",
+ num_outputs_, classification_heads_.size()),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+
+ int num_quantized_outputs = 0;
+ for (int i = 0; i < num_outputs_; ++i) {
+ const TfLiteTensor* output_tensor =
+ TfLiteEngine::GetOutput(engine_->interpreter(), i);
+ const int num_dimensions = output_tensor->dims->size;
+ if (num_dimensions == 4) {
+ if (output_tensor->dims->data[1] != 1 ||
+ output_tensor->dims->data[2] != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Unexpected WxH sizes for output index %d: got "
+ "%dx%d, expected 1x1.",
+ i, output_tensor->dims->data[2],
+ output_tensor->dims->data[1]),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ } else if (num_dimensions != 2) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Unexpected number of dimensions for output index %d: got %dD, "
+ "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, "
+ "H=1).",
+ i, num_dimensions),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ if (output_tensor->dims->data[0] != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("The output array is expected to have a batch size "
+ "of 1. Got %d for output index %d.",
+ output_tensor->dims->data[0], i),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ int num_classes = output_tensor->dims->data[num_dimensions - 1];
+ // If label map is not set, build a default one based on model
+ // introspection. This happens if a model with partial or no metadata was
+ // provided through the `model_file_with_metadata` options field.
+ if (classification_heads_[i].label_map_items.empty()) {
+ classification_heads_[i].label_map_items.reserve(num_classes);
+ for (int class_index = 0; class_index < num_classes; ++class_index) {
+ classification_heads_[i].label_map_items.emplace_back(LabelMapItem{});
+ }
+ }
+ int num_label_map_items = classification_heads_[i].label_map_items.size();
+ if (num_classes != num_label_map_items) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Got %d class(es) for output index %d, expected %d "
+ "according to the label map.",
+ output_tensor->dims->data[num_dimensions - 1], i,
+ num_label_map_items),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+ if (output_tensor->type == kTfLiteUInt8) {
+ num_quantized_outputs++;
+ } else if (output_tensor->type != kTfLiteFloat32) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Type mismatch for output tensor %s. Requested one "
+ "of these types: "
+ "kTfLiteUint8/kTfLiteFloat32, got %s.",
+ output_tensor->name,
+ TfLiteTypeGetName(output_tensor->type)),
+ TfLiteSupportStatus::kInvalidOutputTensorTypeError);
+ }
+ }
+
+ if (num_quantized_outputs > 0 && num_quantized_outputs != num_outputs_) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Got %d quantized output(s), expected %d (i.e. all "
+ "provided outputs must be quantized).",
+ num_quantized_outputs, num_outputs_),
+ TfLiteSupportStatus::kInvalidOutputTensorTypeError);
+ }
+ has_uint8_outputs_ = (num_quantized_outputs > 0);
+
+ return absl::OkStatus();
+}
+
+absl::Status ImageClassifier::CheckAndSetClassNameSet() {
+ // Exit early if no blacklist/whitelist.
+ if (options_->class_name_blacklist_size() == 0 &&
+ options_->class_name_whitelist_size() == 0) {
+ return absl::OkStatus();
+ }
+
+  // Before processing the class name whitelist or blacklist from the input
+  // options, create a set with _all_ known class names from the label map(s).
+ absl::flat_hash_set<std::string> all_class_names;
+ int head_index = 0;
+ for (const auto& head : classification_heads_) {
+ absl::flat_hash_set<std::string> head_class_names;
+ for (const auto& item : head.label_map_items) {
+ if (!item.name.empty()) {
+ head_class_names.insert(item.name);
+ }
+ }
+ if (head_class_names.empty()) {
+ std::string name = head.name;
+ if (name.empty()) {
+ name = absl::StrFormat("#%d", head_index);
+ }
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Using `class_name_whitelist` or `class_name_blacklist` "
+ "requires labels to be present but none was found for "
+ "classification head: %s",
+ name),
+ TfLiteSupportStatus::kMetadataMissingLabelsError);
+ }
+ all_class_names.insert(head_class_names.begin(), head_class_names.end());
+ head_index++;
+ }
+
+ class_name_set_.is_whitelist = options_->class_name_whitelist_size() > 0;
+ const auto& class_names = class_name_set_.is_whitelist
+ ? options_->class_name_whitelist()
+ : options_->class_name_blacklist();
+
+ // Note: duplicate or unknown classes are just ignored.
+ class_name_set_.values.clear();
+ for (const auto& class_name : class_names) {
+ if (!all_class_names.contains(class_name)) {
+ continue;
+ }
+ class_name_set_.values.insert(class_name);
+ }
+
+ if (class_name_set_.values.empty()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Invalid class names specified via `class_name_%s`: none match "
+ "with model labels.",
+ class_name_set_.is_whitelist ? "whitelist" : "blacklist"),
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status ImageClassifier::InitScoreCalibrations() {
+ score_calibrations_.clear();
+ score_calibrations_.resize(classification_heads_.size());
+
+ for (int i = 0; i < classification_heads_.size(); ++i) {
+ if (!classification_heads_[i].calibration_params.has_value()) {
+ continue;
+ }
+
+ // Use a specific default score instead of the one specified by default in
+ // cc/task/vision/utils/score_calibration.h. See `kDefaultCalibratedScore`
+ // documentation for more details.
+ classification_heads_[i].calibration_params->default_score =
+ kDefaultCalibratedScore;
+
+ score_calibrations_[i] = absl::make_unique<ScoreCalibration>();
+ if (score_calibrations_[i] == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal, "Could not create score calibration object.");
+ }
+
+ RETURN_IF_ERROR(score_calibrations_[i]->InitializeFromParameters(
+ classification_heads_[i].calibration_params.value()));
+ }
+
+ return absl::OkStatus();
+}
+
+StatusOr<ClassificationResult> ImageClassifier::Classify(
+ const FrameBuffer& frame_buffer) {
+ BoundingBox roi;
+ roi.set_width(frame_buffer.dimension().width);
+ roi.set_height(frame_buffer.dimension().height);
+ return Classify(frame_buffer, roi);
+}
+
+StatusOr<ClassificationResult> ImageClassifier::Classify(
+ const FrameBuffer& frame_buffer, const BoundingBox& roi) {
+ return InferWithFallback(frame_buffer, roi);
+}
+
+StatusOr<ClassificationResult> ImageClassifier::Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
+ if (output_tensors.size() != num_outputs_) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Expected %d output tensors, found %d", num_outputs_,
+ output_tensors.size()));
+ }
+
+ ClassificationResult result;
+ std::vector<std::pair<int, float>> score_pairs;
+
+ for (int i = 0; i < num_outputs_; ++i) {
+ auto* classifications = result.add_classifications();
+ classifications->set_head_index(i);
+
+ const auto& head = classification_heads_[i];
+ score_pairs.clear();
+ score_pairs.reserve(head.label_map_items.size());
+
+ const TfLiteTensor* output_tensor = output_tensors[i];
+ if (has_uint8_outputs_) {
+ const uint8* output_data =
+ AssertAndReturnTypedTensor<uint8>(output_tensor);
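+      // Dequantize using TFLite's affine quantization scheme:
+      // real_value = scale * (quantized_value - zero_point).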
+ for (int j = 0; j < head.label_map_items.size(); ++j) {
+ score_pairs.emplace_back(j, output_tensor->params.scale *
+ (static_cast<int>(output_data[j]) -
+ output_tensor->params.zero_point));
+ }
+ } else {
+ const float* output_data =
+ AssertAndReturnTypedTensor<float>(output_tensor);
+ for (int j = 0; j < head.label_map_items.size(); ++j) {
+ score_pairs.emplace_back(j, output_data[j]);
+ }
+ }
+
+ // Optional score calibration.
+ if (score_calibrations_[i] != nullptr) {
+ for (auto& score_pair : score_pairs) {
+ const std::string& class_name =
+ head.label_map_items[score_pair.first].name;
+ score_pair.second = score_calibrations_[i]->ComputeCalibratedScore(
+ class_name, score_pair.second);
+ if (score_pair.second > kMaxCalibratedScore) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("calibrated score is too high: got %f, expected "
+ "%f as maximum.",
+ score_pair.second, kMaxCalibratedScore));
+ }
+ if (score_pair.second != kDefaultCalibratedScore &&
+ score_pair.second < kMinCalibratedScore) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("calibrated score is too low: got %f, expected "
+ "%f as minimum.",
+ score_pair.second, kMinCalibratedScore));
+ }
+ }
+ }
+
+ int num_results =
+ options_->max_results() >= 0
+ ? std::min(static_cast<int>(head.label_map_items.size()),
+ options_->max_results())
+ : head.label_map_items.size();
+ float score_threshold = options_->has_score_threshold()
+ ? options_->score_threshold()
+ : head.score_threshold;
+
+ if (class_name_set_.values.empty()) {
+ // Partially sort in descending order (higher score is better).
+ absl::c_partial_sort(
+ score_pairs, score_pairs.begin() + num_results,
+ [](const std::pair<int, float>& a, const std::pair<int, float>& b) {
+ return a.second > b.second;
+ });
+
+ for (int j = 0; j < num_results; ++j) {
+ float score = score_pairs[j].second;
+ if (score < score_threshold) {
+ break;
+ }
+ auto* cl = classifications->add_classes();
+ cl->set_index(score_pairs[j].first);
+ cl->set_score(score);
+ }
+ } else {
+ // Sort in descending order (higher score is better).
+ absl::c_sort(score_pairs, [](const std::pair<int, float>& a,
+ const std::pair<int, float>& b) {
+ return a.second > b.second;
+ });
+
+ for (int j = 0; j < head.label_map_items.size(); ++j) {
+ float score = score_pairs[j].second;
+ if (score < score_threshold ||
+ classifications->classes_size() >= num_results) {
+ break;
+ }
+
+ const int class_index = score_pairs[j].first;
+ const std::string& class_name = head.label_map_items[class_index].name;
+
+ bool class_name_found = class_name_set_.values.contains(class_name);
+
+ if ((!class_name_found && class_name_set_.is_whitelist) ||
+ (class_name_found && !class_name_set_.is_whitelist)) {
+ continue;
+ }
+
+ auto* cl = classifications->add_classes();
+ cl->set_index(class_index);
+ cl->set_score(score);
+ }
+ }
+ }
+
+ RETURN_IF_ERROR(FillResultsFromLabelMaps(&result));
+
+ return result;
+}
+
+absl::Status ImageClassifier::FillResultsFromLabelMaps(
+ ClassificationResult* result) {
+ for (int i = 0; i < result->classifications_size(); ++i) {
+ Classifications* classifications = result->mutable_classifications(i);
+ int head_index = classifications->head_index();
+ if (head_index < 0 || head_index >= classification_heads_.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Invalid head index (%d) with respect to total "
+ "number of classification heads (%d).",
+ head_index, classification_heads_.size()),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+ const std::vector<LabelMapItem>& label_map_items =
+ classification_heads_[head_index].label_map_items;
+ for (int j = 0; j < classifications->classes_size(); ++j) {
+ Class* current_class = classifications->mutable_classes(j);
+ int current_class_index = current_class->index();
+ if (current_class_index < 0 ||
+ current_class_index >= label_map_items.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Invalid class index (%d) with respect to label "
+ "map size (%d) for head #%d.",
+ current_class_index, label_map_items.size(),
+ head_index),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+ const std::string& name = label_map_items[current_class_index].name;
+ if (!name.empty()) {
+ current_class->set_class_name(name);
+ }
+ const std::string& display_name =
+ label_map_items[current_class_index].display_name;
+ if (!display_name.empty()) {
+ current_class->set_display_name(display_name);
+ }
+ }
+ }
+ return absl::OkStatus();
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/image_classifier.h b/tensorflow_lite_support/cc/task/vision/image_classifier.h
new file mode 100644
index 00000000..edd90931
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/image_classifier.h
@@ -0,0 +1,182 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
+#include "tensorflow_lite_support/cc/task/vision/core/classification_head.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Performs classification on images.
+//
+// The API expects a TFLite model with optional, but strongly recommended,
+// TFLite Model Metadata.
+//
+// Input tensor:
+// (kTfLiteUInt8/kTfLiteFloat32)
+// - image input of size `[batch x height x width x channels]`.
+// - batch inference is not supported (`batch` is required to be 1).
+// - only RGB inputs are supported (`channels` is required to be 3).
+// - if type is kTfLiteFloat32, NormalizationOptions are required to be
+// attached to the metadata for input normalization.
+// At least one output tensor with:
+// (kTfLiteUInt8/kTfLiteFloat32)
+//  - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
+//    `[1 x 1 x 1 x N]`.
+// - optional (but recommended) label map(s) as AssociatedFile-s with type
+// TENSOR_AXIS_LABELS, containing one label per line. The first such
+// AssociatedFile (if any) is used to fill the `class_name` field of the
+// results. The `display_name` field is filled from the AssociatedFile (if
+// any) whose locale matches the `display_names_locale` field of the
+// `ImageClassifierOptions` used at creation time ("en" by default, i.e.
+// English). If none of these are available, only the `index` field of the
+// results will be filled.
+//
+// An example of such model can be found at:
+// https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1
+//
+// A CLI demo tool is available for easily trying out this API, and provides
+// example usage. See:
+// examples/task/vision/desktop/image_classifier_demo.cc
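+//
+// A minimal usage sketch (illustrative only; error handling and FrameBuffer
+// creation are omitted, and the model path below is a placeholder):
+//
+//   ImageClassifierOptions options;
+//   options.mutable_model_file_with_metadata()->set_file_name(
+//       "/path/to/model_with_metadata.tflite");
+//   ASSIGN_OR_RETURN(std::unique_ptr<ImageClassifier> classifier,
+//                    ImageClassifier::CreateFromOptions(options));
+//   ASSIGN_OR_RETURN(ClassificationResult result,
+//                    classifier->Classify(*frame_buffer));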
+class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
+ public:
+ using BaseVisionTaskApi::BaseVisionTaskApi;
+
+ // Creates an ImageClassifier from the provided options. A non-default
+ // OpResolver can be specified in order to support custom Ops or specify a
+ // subset of built-in Ops.
+ static tflite::support::StatusOr<std::unique_ptr<ImageClassifier>>
+ CreateFromOptions(
+ const ImageClassifierOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+
+ // Performs actual classification on the provided FrameBuffer.
+ //
+ // The FrameBuffer can be of any size and any of the supported formats, i.e.
+ // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before
+ // inference in order to (and in this order):
+ // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to
+ // the dimensions of the model input tensor,
+ // - convert it to the colorspace of the input tensor (i.e. RGB, which is the
+ // only supported colorspace for now),
+ // - rotate it according to its `Orientation` so that inference is performed
+ // on an "upright" image.
+ tflite::support::StatusOr<ClassificationResult> Classify(
+ const FrameBuffer& frame_buffer);
+
+ // Same as above, except that the classification is performed based on the
+ // input region of interest. Cropping according to this region of interest is
+ // prepended to the pre-processing operations.
+ //
+ // IMPORTANT: as a consequence of cropping occurring first, the provided
+ // region of interest is expressed in the unrotated frame of reference
+ // coordinates system, i.e. in `[0, frame_buffer.width) x [0,
+ // frame_buffer.height)`, which are the dimensions of the underlying
+ // `frame_buffer` data before any `Orientation` flag gets applied. Also, the
+ // region of interest is not clamped, so this method will return a non-ok
+ // status if the region is out of these bounds.
+ tflite::support::StatusOr<ClassificationResult> Classify(
+ const FrameBuffer& frame_buffer, const BoundingBox& roi);
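+  //
+  // Illustrative sketch (assuming the `BoundingBox` proto exposes
+  // origin_x/origin_y setters): classify the top-left 100x100 region of the
+  // underlying buffer:
+  //
+  //   BoundingBox roi;
+  //   roi.set_origin_x(0);
+  //   roi.set_origin_y(0);
+  //   roi.set_width(100);
+  //   roi.set_height(100);
+  //   auto result = classifier->Classify(*frame_buffer, roi);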
+
+ protected:
+ // The options used to build this ImageClassifier.
+ std::unique_ptr<ImageClassifierOptions> options_;
+
+ // The list of classification heads associated with the corresponding output
+ // tensors. Built from TFLite Model Metadata.
+ std::vector<ClassificationHead> classification_heads_;
+
+ // Post-processing to transform the raw model outputs into classification
+ // results.
+ tflite::support::StatusOr<ClassificationResult> Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
+
+ // Performs sanity checks on the provided ImageClassifierOptions.
+ static absl::Status SanityCheckOptions(const ImageClassifierOptions& options);
+
+ // Initializes the ImageClassifier from the provided ImageClassifierOptions,
+ // whose ownership is transferred to this object.
+ absl::Status Init(std::unique_ptr<ImageClassifierOptions> options);
+
+ // Performs pre-initialization actions.
+ virtual absl::Status PreInit();
+ // Performs post-initialization actions.
+ virtual absl::Status PostInit();
+
+ private:
+ // Performs sanity checks on the model outputs and extracts their metadata.
+ absl::Status CheckAndSetOutputs();
+
+ // Performs sanity checks on the class whitelist/blacklist and forms the class
+ // name set.
+ absl::Status CheckAndSetClassNameSet();
+
+ // Initializes the score calibration parameters based on corresponding TFLite
+ // Model Metadata, if any.
+ absl::Status InitScoreCalibrations();
+
+ // Given a ClassificationResult object containing class indices, fills the
+ // name and display name from the label map(s).
+ absl::Status FillResultsFromLabelMaps(ClassificationResult* result);
+
+ // The number of output tensors. This corresponds to the number of
+ // classification heads.
+ int num_outputs_;
+  // Whether the model features quantized inference type (QUANTIZED_UINT8).
+  // This is currently detected by checking whether all output tensors' data
+  // type is uint8.
+ bool has_uint8_outputs_;
+
+ // Set of whitelisted or blacklisted class names.
+ struct ClassNameSet {
+ absl::flat_hash_set<std::string> values;
+ bool is_whitelist;
+ };
+
+ // Whitelisted or blacklisted class names based on provided options at
+ // construction time. These are used to filter out results during
+ // post-processing.
+ ClassNameSet class_name_set_;
+
+ // List of score calibration parameters, if any. Built from TFLite Model
+ // Metadata.
+ std::vector<std::unique_ptr<ScoreCalibration>> score_calibrations_;
+};
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
diff --git a/tensorflow_lite_support/cc/task/vision/image_segmenter.cc b/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
new file mode 100644
index 00000000..4523b662
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
@@ -0,0 +1,427 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
+
+#include <algorithm>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+
+using ::absl::StatusCode;
+using ::tflite::TensorMetadata;
+using ::tflite::metadata::ModelMetadataExtractor;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::task::core::AssertAndReturnTypedTensor;
+using ::tflite::task::core::TaskAPIFactory;
+using ::tflite::task::core::TfLiteEngine;
+
+// The maximum number of labels allowed in the labelmap. This is because, so
+// far, segmentation masks are stored with 8 bits per pixel (as a flattened
+// byte array).
+constexpr uint32 kMaxNumClasses = 256;
+
+// TODO(b/)
+// The colormap used to fill `ColoredLabel`-s, as a flattened array of 256 {R,
+// G, B} components.
+constexpr uint8 kColorMap[768] = {
+ 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128,
+ 128, 0, 128, 0, 128, 128, 128, 128, 128, 64, 0, 0, 192, 0, 0,
+ 64, 128, 0, 192, 128, 0, 64, 0, 128, 192, 0, 128, 64, 128, 128,
+ 192, 128, 128, 0, 64, 0, 128, 64, 0, 0, 192, 0, 128, 192, 0,
+ 0, 64, 128, 128, 64, 128, 0, 192, 128, 128, 192, 128, 64, 64, 0,
+ 192, 64, 0, 64, 192, 0, 192, 192, 0, 64, 64, 128, 192, 64, 128,
+ 64, 192, 128, 192, 192, 128, 0, 0, 64, 128, 0, 64, 0, 128, 64,
+ 128, 128, 64, 0, 0, 192, 128, 0, 192, 0, 128, 192, 128, 128, 192,
+ 64, 0, 64, 192, 0, 64, 64, 128, 64, 192, 128, 64, 64, 0, 192,
+ 192, 0, 192, 64, 128, 192, 192, 128, 192, 0, 64, 64, 128, 64, 64,
+ 0, 192, 64, 128, 192, 64, 0, 64, 192, 128, 64, 192, 0, 192, 192,
+ 128, 192, 192, 64, 64, 64, 192, 64, 64, 64, 192, 64, 192, 192, 64,
+ 64, 64, 192, 192, 64, 192, 64, 192, 192, 192, 192, 192, 32, 0, 0,
+ 160, 0, 0, 32, 128, 0, 160, 128, 0, 32, 0, 128, 160, 0, 128,
+ 32, 128, 128, 160, 128, 128, 96, 0, 0, 224, 0, 0, 96, 128, 0,
+ 224, 128, 0, 96, 0, 128, 224, 0, 128, 96, 128, 128, 224, 128, 128,
+ 32, 64, 0, 160, 64, 0, 32, 192, 0, 160, 192, 0, 32, 64, 128,
+ 160, 64, 128, 32, 192, 128, 160, 192, 128, 96, 64, 0, 224, 64, 0,
+ 96, 192, 0, 224, 192, 0, 96, 64, 128, 224, 64, 128, 96, 192, 128,
+ 224, 192, 128, 32, 0, 64, 160, 0, 64, 32, 128, 64, 160, 128, 64,
+ 32, 0, 192, 160, 0, 192, 32, 128, 192, 160, 128, 192, 96, 0, 64,
+ 224, 0, 64, 96, 128, 64, 224, 128, 64, 96, 0, 192, 224, 0, 192,
+ 96, 128, 192, 224, 128, 192, 32, 64, 64, 160, 64, 64, 32, 192, 64,
+ 160, 192, 64, 32, 64, 192, 160, 64, 192, 32, 192, 192, 160, 192, 192,
+ 96, 64, 64, 224, 64, 64, 96, 192, 64, 224, 192, 64, 96, 64, 192,
+ 224, 64, 192, 96, 192, 192, 224, 192, 192, 0, 32, 0, 128, 32, 0,
+ 0, 160, 0, 128, 160, 0, 0, 32, 128, 128, 32, 128, 0, 160, 128,
+ 128, 160, 128, 64, 32, 0, 192, 32, 0, 64, 160, 0, 192, 160, 0,
+ 64, 32, 128, 192, 32, 128, 64, 160, 128, 192, 160, 128, 0, 96, 0,
+ 128, 96, 0, 0, 224, 0, 128, 224, 0, 0, 96, 128, 128, 96, 128,
+ 0, 224, 128, 128, 224, 128, 64, 96, 0, 192, 96, 0, 64, 224, 0,
+ 192, 224, 0, 64, 96, 128, 192, 96, 128, 64, 224, 128, 192, 224, 128,
+ 0, 32, 64, 128, 32, 64, 0, 160, 64, 128, 160, 64, 0, 32, 192,
+ 128, 32, 192, 0, 160, 192, 128, 160, 192, 64, 32, 64, 192, 32, 64,
+ 64, 160, 64, 192, 160, 64, 64, 32, 192, 192, 32, 192, 64, 160, 192,
+ 192, 160, 192, 0, 96, 64, 128, 96, 64, 0, 224, 64, 128, 224, 64,
+ 0, 96, 192, 128, 96, 192, 0, 224, 192, 128, 224, 192, 64, 96, 64,
+ 192, 96, 64, 64, 224, 64, 192, 224, 64, 64, 96, 192, 192, 96, 192,
+ 64, 224, 192, 192, 224, 192, 32, 32, 0, 160, 32, 0, 32, 160, 0,
+ 160, 160, 0, 32, 32, 128, 160, 32, 128, 32, 160, 128, 160, 160, 128,
+ 96, 32, 0, 224, 32, 0, 96, 160, 0, 224, 160, 0, 96, 32, 128,
+ 224, 32, 128, 96, 160, 128, 224, 160, 128, 32, 96, 0, 160, 96, 0,
+ 32, 224, 0, 160, 224, 0, 32, 96, 128, 160, 96, 128, 32, 224, 128,
+ 160, 224, 128, 96, 96, 0, 224, 96, 0, 96, 224, 0, 224, 224, 0,
+ 96, 96, 128, 224, 96, 128, 96, 224, 128, 224, 224, 128, 32, 32, 64,
+ 160, 32, 64, 32, 160, 64, 160, 160, 64, 32, 32, 192, 160, 32, 192,
+ 32, 160, 192, 160, 160, 192, 96, 32, 64, 224, 32, 64, 96, 160, 64,
+ 224, 160, 64, 96, 32, 192, 224, 32, 192, 96, 160, 192, 224, 160, 192,
+ 32, 96, 64, 160, 96, 64, 32, 224, 64, 160, 224, 64, 32, 96, 192,
+ 160, 96, 192, 32, 224, 192, 160, 224, 192, 96, 96, 64, 224, 96, 64,
+ 96, 224, 64, 224, 224, 64, 96, 96, 192, 224, 96, 192, 96, 224, 192,
+ 224, 224, 192};
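+
+// For illustration, the {R, G, B} color for class index `i` is read from
+// three consecutive entries of the flattened array (a sketch; `i` must be
+// below kMaxNumClasses):
+//
+//   uint8 r = kColorMap[3 * i];
+//   uint8 g = kColorMap[3 * i + 1];
+//   uint8 b = kColorMap[3 * i + 2];
+//
+// e.g. class index 1 maps to {128, 0, 0}.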
+
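+// Builds a label map from the AssociatedFile-s attached to the given tensor
+// metadata, if any: class names come from the first TENSOR_AXIS_LABELS file,
+// display names from the file (if any) whose locale matches `locale`. Returns
+// an empty vector if no label file is attached.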
+StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
+ const ModelMetadataExtractor& metadata_extractor,
+ const TensorMetadata& tensor_metadata, absl::string_view locale) {
+ const std::string labels_filename =
+ ModelMetadataExtractor::FindFirstAssociatedFileName(
+ tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS);
+ if (labels_filename.empty()) {
+ return std::vector<LabelMapItem>();
+ }
+ ASSIGN_OR_RETURN(absl::string_view labels_file,
+ metadata_extractor.GetAssociatedFile(labels_filename));
+ const std::string display_names_filename =
+ ModelMetadataExtractor::FindFirstAssociatedFileName(
+ tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS,
+ locale);
+ absl::string_view display_names_file = nullptr;
+ if (!display_names_filename.empty()) {
+ ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
+ display_names_filename));
+ }
+ return BuildLabelMapFromFiles(labels_file, display_names_file);
+}
+
+} // namespace
+
+/* static */
+absl::Status ImageSegmenter::SanityCheckOptions(
+ const ImageSegmenterOptions& options) {
+ if (!options.has_model_file_with_metadata()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Missing mandatory `model_file_with_metadata` field",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.output_type() == ImageSegmenterOptions::UNSPECIFIED) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "ImageSegmenterOptions: `output_type` must not be UNSPECIFIED",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.num_threads() == 0 || options.num_threads() < -1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "`num_threads` must be greater than 0 or equal to -1.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ return absl::OkStatus();
+}
+
+StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::CreateFromOptions(
+ const ImageSegmenterOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver) {
+ RETURN_IF_ERROR(SanityCheckOptions(options));
+
+ // Copy options to ensure the ExternalFile outlives the constructed object.
+ auto options_copy = absl::make_unique<ImageSegmenterOptions>(options);
+
+ ASSIGN_OR_RETURN(auto image_segmenter,
+ TaskAPIFactory::CreateFromExternalFileProto<ImageSegmenter>(
+ &options_copy->model_file_with_metadata(),
+ std::move(resolver), options_copy->num_threads()));
+
+ RETURN_IF_ERROR(image_segmenter->Init(std::move(options_copy)));
+
+ return image_segmenter;
+}
+
+absl::Status ImageSegmenter::Init(
+ std::unique_ptr<ImageSegmenterOptions> options) {
+ // Set options.
+ options_ = std::move(options);
+
+  // Perform pre-initialization actions; by default, this sets the process
+  // engine for image pre-processing to kLibyuv.
+ RETURN_IF_ERROR(PreInit());
+
+ // Sanity check and set inputs and outputs.
+ RETURN_IF_ERROR(CheckAndSetInputs());
+ RETURN_IF_ERROR(CheckAndSetOutputs());
+
+ // Initialize colored_labels_ once and for all.
+ RETURN_IF_ERROR(InitColoredLabels());
+
+ return absl::OkStatus();
+}
+
+absl::Status ImageSegmenter::PreInit() {
+ SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
+ return absl::OkStatus();
+}
+
+absl::Status ImageSegmenter::CheckAndSetOutputs() {
+ // First, sanity checks on the model itself.
+ const TfLiteEngine::Interpreter* interpreter = engine_->interpreter();
+
+ // Check the number of output tensors.
+ if (TfLiteEngine::OutputCount(interpreter) != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Image segmentation models are expected to have only 1 "
+ "output, found %d",
+ TfLiteEngine::OutputCount(interpreter)),
+ TfLiteSupportStatus::kInvalidNumOutputTensorsError);
+ }
+ const TfLiteTensor* output_tensor = TfLiteEngine::GetOutput(interpreter, 0);
+
+ // Check tensor dimensions.
+ if (output_tensor->dims->size != 4) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Output tensor is expected to have 4 dimensions, found %d.",
+ output_tensor->dims->size),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ if (output_tensor->dims->data[0] != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Expected batch size of 1, found %d.",
+ output_tensor->dims->data[0]),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ output_height_ = output_tensor->dims->data[1];
+ output_width_ = output_tensor->dims->data[2];
+ output_depth_ = output_tensor->dims->data[3];
+ if (output_depth_ > kMaxNumClasses) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Expected at most %d output classes, found %d",
+ kMaxNumClasses, output_depth_),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+
+ // Check tensor type.
+ if (output_tensor->type != kTfLiteFloat32 &&
+ output_tensor->type != kTfLiteUInt8) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Type mismatch for output tensor. Requested one of "
+ "these types: kTfLiteUint8/kTfLiteFloat32, got %s.",
+ TfLiteTypeGetName(output_tensor->type)),
+ TfLiteSupportStatus::kInvalidOutputTensorTypeError);
+ }
+ has_uint8_outputs_ = (output_tensor->type == kTfLiteUInt8);
+
+ // Build label map from metadata, if available.
+ const ModelMetadataExtractor* metadata_extractor =
+ engine_->metadata_extractor();
+ const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
+ output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();
+ if (output_tensor_metadata != nullptr) {
+ // Check metadata consistency.
+ if (output_tensor_metadata->size() != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Mismatch between number of output tensors (1) and "
+ "output tensors metadata (%d).",
+ output_tensor_metadata->size()),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+ ASSIGN_OR_RETURN(
+ label_map_,
+ GetLabelMapIfAny(*metadata_extractor, *output_tensor_metadata->Get(0),
+ options_->display_names_locale()));
+ }
+
+ // If label map is still empty, build a default one.
+ if (label_map_.empty()) {
+ for (int class_index = 0; class_index < output_depth_; ++class_index) {
+ label_map_.emplace_back(LabelMapItem{});
+ }
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status ImageSegmenter::InitColoredLabels() {
+ for (int i = 0; i < label_map_.size(); ++i) {
+ Segmentation::ColoredLabel colored_label;
+ colored_label.set_r(kColorMap[3 * i]);
+ colored_label.set_g(kColorMap[3 * i + 1]);
+ colored_label.set_b(kColorMap[3 * i + 2]);
+ const LabelMapItem& item = label_map_[i];
+ if (!item.name.empty()) {
+ colored_label.set_class_name(item.name);
+ }
+ if (!item.display_name.empty()) {
+ colored_label.set_display_name(item.display_name);
+ }
+ colored_labels_.push_back(colored_label);
+ }
+ return absl::OkStatus();
+}
+
+StatusOr<SegmentationResult> ImageSegmenter::Segment(
+ const FrameBuffer& frame_buffer) {
+ BoundingBox roi;
+ roi.set_width(frame_buffer.dimension().width);
+ roi.set_height(frame_buffer.dimension().height);
+ return InferWithFallback(frame_buffer, roi);
+}
+
+StatusOr<SegmentationResult> ImageSegmenter::Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
+ if (output_tensors.size() != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Expected 1 output tensors, found %d",
+ output_tensors.size()));
+ }
+ const TfLiteTensor* output_tensor = output_tensors[0];
+
+ SegmentationResult result;
+ Segmentation* segmentation = result.add_segmentation();
+ *segmentation->mutable_colored_labels() = {colored_labels_.begin(),
+ colored_labels_.end()};
+
+ // The output tensor has orientation `frame_buffer.orientation()`, as it has
+ // been produced from the pre-processed frame.
+ FrameBuffer::Orientation tensor_orientation = frame_buffer.orientation();
+  // The output tensor always has size `output_width_ x output_height_`.
+ FrameBuffer::Dimension tensor_dimension = {output_width_, output_height_};
+
+ // The masks to produce from the output tensor need to be re-oriented in the
+  // unrotated frame of reference coordinate system, i.e. kTopLeft.
+ FrameBuffer::Orientation mask_orientation =
+ FrameBuffer::Orientation::kTopLeft;
+ // They may thus have swapped dimensions compared to the tensor if the
+ // rotation is 90° or 270°.
+ FrameBuffer::Dimension mask_dimension(tensor_dimension);
+ if (RequireDimensionSwap(frame_buffer.orientation(),
+ FrameBuffer::Orientation::kTopLeft)) {
+ mask_dimension.Swap();
+ }
+ segmentation->set_width(mask_dimension.width);
+ segmentation->set_height(mask_dimension.height);
+
+ // XY coordinates in the tensor, to be computed from mask_x and mask_y below.
+ int tensor_x;
+ int tensor_y;
+
+ if (options_->output_type() == ImageSegmenterOptions::CATEGORY_MASK) {
+ auto* category_mask = segmentation->mutable_category_mask();
+ category_mask->resize(mask_dimension.width * mask_dimension.height);
+ int pixel_offset = 0;
+ for (int mask_y = 0; mask_y < mask_dimension.height; ++mask_y) {
+ for (int mask_x = 0; mask_x < mask_dimension.width; ++mask_x) {
+ // Compute the coordinates (tensor_x, tensor_y) in the tensor with
+ // tensor_orientation = frame_buffer.orientation() corresponding to the
+ // coordinates (mask_x, mask_y) in the mask being filled with
+ // mask_orientation = kTopLeft, i.e. the orientation of the unrotated
+ // frame of reference.
+ OrientCoordinates(/*from_x=*/mask_x,
+ /*from_y=*/mask_y,
+ /*from_orientation=*/mask_orientation,
+ /*to_orientation=*/tensor_orientation,
+ /*from_dimension=*/mask_dimension,
+ /*to_x=*/&tensor_x,
+ /*to_y=*/&tensor_y);
+ int class_index = 0;
+ float max_confidence = 0.0f;
+ for (int d = 0; d < output_depth_; ++d) {
+ const float confidence =
+ GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d);
+ if (confidence > max_confidence) {
+ class_index = d;
+ max_confidence = confidence;
+ }
+ }
+ (*category_mask)[pixel_offset++] = static_cast<char>(class_index);
+ }
+ }
+ } else if (options_->output_type() ==
+ ImageSegmenterOptions::CONFIDENCE_MASK) {
+ auto* confidence_masks = segmentation->mutable_confidence_masks();
+ for (int d = 0; d < output_depth_; ++d) {
+ confidence_masks->add_confidence_mask();
+ }
+ for (int mask_y = 0; mask_y < segmentation->height(); ++mask_y) {
+ for (int mask_x = 0; mask_x < segmentation->width(); ++mask_x) {
+ // See above.
+ OrientCoordinates(/*from_x=*/mask_x,
+ /*from_y=*/mask_y,
+ /*from_orientation=*/mask_orientation,
+ /*to_orientation=*/tensor_orientation,
+ /*from_dimension=*/mask_dimension,
+ /*to_x=*/&tensor_x,
+ /*to_y=*/&tensor_y);
+ for (int d = 0; d < output_depth_; ++d) {
+ confidence_masks->mutable_confidence_mask(d)->add_value(
+ GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d));
+ }
+ }
+ }
+ }
+
+ return result;
+}
+
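+// Worked dequantization example for the uint8 branch below (illustrative
+// values): with scale = 1/256 = 0.00390625 and zero_point = 0, a raw byte of
+// 200 dequantizes to 0.00390625 * (200 - 0) = 0.78125.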
+float ImageSegmenter::GetOutputConfidence(const TfLiteTensor& output_tensor,
+ int x, int y, int depth) {
+ int index = output_width_ * output_depth_ * y + output_depth_ * x + depth;
+ if (has_uint8_outputs_) {
+ const uint8* data = AssertAndReturnTypedTensor<uint8>(&output_tensor);
+ return output_tensor.params.scale *
+ (static_cast<int>(data[index]) - output_tensor.params.zero_point);
+ } else {
+ const float* data = AssertAndReturnTypedTensor<float>(&output_tensor);
+ return data[index];
+ }
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/image_segmenter.h b/tensorflow_lite_support/cc/task/vision/image_segmenter.h
new file mode 100644
index 00000000..663ddb70
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/image_segmenter.h
@@ -0,0 +1,172 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Performs segmentation on images.
+//
+// The API expects a TFLite model with optional, but strongly recommended,
+// TFLite Model Metadata.
+//
+// Input tensor:
+// (kTfLiteUInt8/kTfLiteFloat32)
+// - image input of size `[batch x height x width x channels]`.
+// - batch inference is not supported (`batch` is required to be 1).
+// - only RGB inputs are supported (`channels` is required to be 3).
+// - if type is kTfLiteFloat32, NormalizationOptions are required to be
+// attached to the metadata for input normalization.
+// Output tensor:
+// (kTfLiteUInt8/kTfLiteFloat32)
+// - tensor of size `[batch x mask_height x mask_width x num_classes]`, where
+// `batch` is required to be 1, `mask_width` and `mask_height` are the
+// dimensions of the segmentation masks produced by the model, and
+// `num_classes` is the number of classes supported by the model.
+// - optional (but recommended) label map(s) can be attached as
+// AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per
+// line. The first such AssociatedFile (if any) is used to fill the
+// `class_name` field of the results. The `display_name` field is filled
+// from the AssociatedFile (if any) whose locale matches the
+// `display_names_locale` field of the `ImageSegmenterOptions` used at
+// creation time ("en" by default, i.e. English). If none of these are
+// available, only the `index` field of the results will be filled.
+//
+// An example of such a model can be found at:
+// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1
+//
+// A CLI demo tool is available for easily trying out this API, and provides
+// example usage. See:
+// examples/task/vision/desktop/image_segmenter_demo.cc
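+//
+// Usage example (a minimal sketch, to be read inside a function returning
+// absl::Status; the model path is assumed to be passed via the
+// `ExternalFile.file_name` field, and `BuildRgbFrameBuffer` is a hypothetical
+// helper wrapping raw RGB pixel data):
+//
+//   ImageSegmenterOptions options;
+//   options.mutable_model_file_with_metadata()->set_file_name(
+//       "/path/to/model.tflite");
+//   ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> segmenter,
+//                    ImageSegmenter::CreateFromOptions(options));
+//   std::unique_ptr<FrameBuffer> frame_buffer =
+//       BuildRgbFrameBuffer(rgb_data, /*width=*/640, /*height=*/480);
+//   ASSIGN_OR_RETURN(SegmentationResult result,
+//                    segmenter->Segment(*frame_buffer));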
+class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> {
+ public:
+ using BaseVisionTaskApi::BaseVisionTaskApi;
+
+ // Creates an ImageSegmenter from the provided options. A non-default
+ // OpResolver can be specified in order to support custom Ops or specify a
+ // subset of built-in Ops.
+ static tflite::support::StatusOr<std::unique_ptr<ImageSegmenter>>
+ CreateFromOptions(
+ const ImageSegmenterOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+
+ // Performs actual segmentation on the provided FrameBuffer.
+ //
+ // The FrameBuffer can be of any size and any of the supported formats, i.e.
+ // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before
+ // inference in order to (and in this order):
+ // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to
+ // the dimensions of the model input tensor,
+ // - convert it to the colorspace of the input tensor (i.e. RGB, which is the
+ // only supported colorspace for now),
+ // - rotate it according to its `Orientation` so that inference is performed
+ // on an "upright" image.
+ //
+  // IMPORTANT: the returned segmentation masks are not directly suited for
+ // display, in particular:
+ // * they are relative to the unrotated input frame, i.e. *not* taking into
+ // account the `Orientation` flag of the input FrameBuffer,
+ // * their dimensions are intrinsic to the model, i.e. *not* dependent on the
+ // input FrameBuffer dimensions.
+ //
+ // Example of such post-processing, assuming:
+ // * an input FrameBuffer with width=640, height=480, orientation=kLeftBottom
+ // (i.e. the image will be rotated 90° clockwise during preprocessing to
+ // make it "upright"),
+ // * a model outputting masks of size 224x224.
+ // In order to be directly displayable on top of the input image assumed to
+ // be displayed *with* the `Orientation` flag taken into account according to
+ // the EXIF specification (http://jpegclub.org/exif_orientation.html), the
+ // masks need to be:
+ // * re-scaled to 640 x 480,
+ // * then rotated 90° clockwise.
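+  //
+  // A commented sketch of that post-processing, assuming hypothetical `Image`,
+  // `ResizeBilinear` and `RotateClockwise` helpers (none of which is part of
+  // this API):
+  //
+  //   // `mask` is the 224x224 mask extracted from the SegmentationResult.
+  //   Image resized = ResizeBilinear(mask, /*width=*/640, /*height=*/480);
+  //   Image displayable = RotateClockwise(resized, /*degrees=*/90);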
+ tflite::support::StatusOr<SegmentationResult> Segment(
+ const FrameBuffer& frame_buffer);
+
+ protected:
+ // Post-processing to transform the raw model outputs into segmentation
+ // results.
+ tflite::support::StatusOr<SegmentationResult> Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
+
+ // Performs sanity checks on the provided ImageSegmenterOptions.
+ static absl::Status SanityCheckOptions(const ImageSegmenterOptions& options);
+
+ // Initializes the Segmenter from the provided ImageSegmenterOptions, whose
+ // ownership is transferred to this object.
+ absl::Status Init(std::unique_ptr<ImageSegmenterOptions> options);
+
+ // Performs pre-initialization actions.
+ virtual absl::Status PreInit();
+
+ // The options used for building this image segmenter.
+ std::unique_ptr<ImageSegmenterOptions> options_;
+
+ // The label map, extracted from the TFLite Model Metadata.
+ std::vector<LabelMapItem> label_map_;
+
+ private:
+ // Performs sanity checks on the model outputs and extracts their metadata.
+ absl::Status CheckAndSetOutputs();
+
+ // Initializes the colored labels list from `label_map_` and stores it in
+ // `colored_labels_`.
+ absl::Status InitColoredLabels();
+
+ // Returns the output confidence at coordinates {x, y, depth}, dequantizing
+ // on-the-fly if needed (i.e. if `has_uint8_outputs_` is true).
+ float GetOutputConfidence(const TfLiteTensor& output_tensor, int x, int y,
+ int depth);
+
+ // Prebuilt list of ColoredLabel attached to each Segmentation result. The
+ // i-th item in this list corresponds to the i-th label map item.
+ std::vector<Segmentation::ColoredLabel> colored_labels_;
+
+ // Whether the model features quantized inference type (QUANTIZED_UINT8). This
+  // is currently detected by checking whether all output tensors have uint8
+  // data type.
+ bool has_uint8_outputs_;
+
+ // Expected output width.
+ int output_width_;
+ // Expected output height.
+ int output_height_;
+ // Expected output depth. This corresponds to the number of supported classes.
+ int output_depth_;
+};
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_
diff --git a/tensorflow_lite_support/cc/task/vision/object_detector.cc b/tensorflow_lite_support/cc/task/vision/object_detector.cc
new file mode 100644
index 00000000..22ec3019
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/object_detector.cc
@@ -0,0 +1,549 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/vision/object_detector.h"
+
+#include <algorithm>
+#include <limits>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+
+using ::absl::StatusCode;
+using ::tflite::BoundingBoxProperties;
+using ::tflite::ContentProperties;
+using ::tflite::ContentProperties_BoundingBoxProperties;
+using ::tflite::EnumNameContentProperties;
+using ::tflite::ProcessUnit;
+using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions;
+using ::tflite::TensorMetadata;
+using ::tflite::metadata::ModelMetadataExtractor;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::task::core::AssertAndReturnTypedTensor;
+using ::tflite::task::core::TaskAPIFactory;
+using ::tflite::task::core::TfLiteEngine;
+
+// The expected number of dimensions of the 4 output tensors, representing in
+// that order: locations, classes, scores, num_results.
+static constexpr int kOutputTensorsExpectedDims[4] = {3, 2, 2, 1};
+
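+// Extracts the BoundingBoxProperties attached to the locations tensor
+// metadata, checking on the fly that they match what Mobile SSD expects, i.e.
+// BOUNDARIES boxes with RATIO coordinates and, if present, a 4-element index
+// array.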
+StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
+ const TensorMetadata& tensor_metadata) {
+ if (tensor_metadata.content() == nullptr ||
+ tensor_metadata.content()->content_properties() == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Expected BoundingBoxProperties for tensor %s, found none.",
+ tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"),
+ TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
+ }
+
+ ContentProperties type = tensor_metadata.content()->content_properties_type();
+ if (type != ContentProperties_BoundingBoxProperties) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Expected BoundingBoxProperties for tensor %s, found %s.",
+ tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
+ EnumNameContentProperties(type)),
+ TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
+ }
+
+ const BoundingBoxProperties* properties =
+ tensor_metadata.content()->content_properties_as_BoundingBoxProperties();
+
+ // Mobile SSD only supports "BOUNDARIES" bounding box type.
+ if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s",
+ tflite::EnumNameBoundingBoxType(properties->type())),
+ TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
+ }
+
+ // Mobile SSD only supports "RATIO" coordinates type.
+ if (properties->coordinate_type() != tflite::CoordinateType_RATIO) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Mobile SSD only supports CoordinateType RATIO, found %s",
+ tflite::EnumNameCoordinateType(properties->coordinate_type())),
+ TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
+ }
+
+ // Index is optional, but must contain 4 values if present.
+ if (properties->index() != nullptr && properties->index()->size() != 4) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Expected BoundingBoxProperties index to contain 4 values, found "
+ "%d",
+ properties->index()->size()),
+ TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
+ }
+
+ return properties;
+}
+
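+// Builds a label map from the AssociatedFile-s attached to the given tensor
+// metadata, if any: class names come from the first TENSOR_VALUE_LABELS file,
+// display names from the file (if any) whose locale matches `locale`. Returns
+// an empty vector if no label file is attached.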
+StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny(
+ const ModelMetadataExtractor& metadata_extractor,
+ const TensorMetadata& tensor_metadata, absl::string_view locale) {
+ const std::string labels_filename =
+ ModelMetadataExtractor::FindFirstAssociatedFileName(
+ tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS);
+ if (labels_filename.empty()) {
+ return std::vector<LabelMapItem>();
+ }
+ ASSIGN_OR_RETURN(absl::string_view labels_file,
+ metadata_extractor.GetAssociatedFile(labels_filename));
+ const std::string display_names_filename =
+ ModelMetadataExtractor::FindFirstAssociatedFileName(
+ tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS,
+ locale);
+ absl::string_view display_names_file = nullptr;
+ if (!display_names_filename.empty()) {
+ ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile(
+ display_names_filename));
+ }
+ return BuildLabelMapFromFiles(labels_file, display_names_file);
+}
+
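+// Returns the score threshold declared via ScoreThresholdingOptions in the
+// scores tensor metadata, or the lowest representable float (i.e. no
+// filtering) if no such process unit is found.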
+StatusOr<float> GetScoreThreshold(
+ const ModelMetadataExtractor& metadata_extractor,
+ const TensorMetadata& tensor_metadata) {
+ ASSIGN_OR_RETURN(
+ const ProcessUnit* score_thresholding_process_unit,
+ metadata_extractor.FindFirstProcessUnit(
+ tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions));
+ if (score_thresholding_process_unit == nullptr) {
+ return std::numeric_limits<float>::lowest();
+ }
+ return score_thresholding_process_unit->options_as_ScoreThresholdingOptions()
+ ->global_score_threshold();
+}
+
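+// Checks that the four output tensors have the shapes Mobile SSD produces.
+// For illustration, with num_results = 25 the expected shapes are: locations
+// [1, 25, 4], classes [1, 25], scores [1, 25], and num_results [1].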
+absl::Status SanityCheckOutputTensors(
+ const std::vector<const TfLiteTensor*>& output_tensors) {
+ if (output_tensors.size() != 4) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Expected 4 output tensors, found %d",
+ output_tensors.size()));
+ }
+
+ // Get number of results.
+ if (output_tensors[3]->dims->data[0] != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat(
+ "Expected tensor with dimensions [1] at index 3, found [%d]",
+ output_tensors[3]->dims->data[0]));
+ }
+ int num_results =
+ static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]);
+
+ // Check dimensions for the other tensors are correct.
+ if (output_tensors[0]->dims->data[0] != 1 ||
+ output_tensors[0]->dims->data[1] != num_results ||
+ output_tensors[0]->dims->data[2] != 4) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat(
+ "Expected locations tensor with dimensions [1,%d,4] at index 0, "
+ "found [%d,%d,%d].",
+ num_results, output_tensors[0]->dims->data[0],
+ output_tensors[0]->dims->data[1],
+ output_tensors[0]->dims->data[2]));
+ }
+ if (output_tensors[1]->dims->data[0] != 1 ||
+ output_tensors[1]->dims->data[1] != num_results) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat(
+ "Expected classes tensor with dimensions [1,%d] at index 1, "
+ "found [%d,%d].",
+ num_results, output_tensors[1]->dims->data[0],
+ output_tensors[1]->dims->data[1]));
+ }
+ if (output_tensors[2]->dims->data[0] != 1 ||
+ output_tensors[2]->dims->data[1] != num_results) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat(
+ "Expected scores tensor with dimensions [1,%d] at index 2, "
+ "found [%d,%d].",
+ num_results, output_tensors[2]->dims->data[0],
+ output_tensors[2]->dims->data[1]));
+ }
+
+ return absl::OkStatus();
+}
+
+} // namespace
+
+/* static */
+absl::Status ObjectDetector::SanityCheckOptions(
+ const ObjectDetectorOptions& options) {
+ if (!options.has_model_file_with_metadata()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Missing mandatory `model_file_with_metadata` field",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.max_results() == 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Invalid `max_results` option: value must be != 0",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.class_name_whitelist_size() > 0 &&
+ options.class_name_blacklist_size() > 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "`class_name_whitelist` and `class_name_blacklist` are mutually "
+ "exclusive options.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.num_threads() == 0 || options.num_threads() < -1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "`num_threads` must be greater than 0 or equal to -1.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ return absl::OkStatus();
+}
+
+/* static */
+StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::CreateFromOptions(
+ const ObjectDetectorOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver) {
+ RETURN_IF_ERROR(SanityCheckOptions(options));
+
+ // Copy options to ensure the ExternalFile outlives the constructed object.
+ auto options_copy = absl::make_unique<ObjectDetectorOptions>(options);
+
+ ASSIGN_OR_RETURN(auto object_detector,
+ TaskAPIFactory::CreateFromExternalFileProto<ObjectDetector>(
+ &options_copy->model_file_with_metadata(),
+ std::move(resolver), options_copy->num_threads()));
+
+ RETURN_IF_ERROR(object_detector->Init(std::move(options_copy)));
+
+ return object_detector;
+}
+
+absl::Status ObjectDetector::Init(
+ std::unique_ptr<ObjectDetectorOptions> options) {
+ // Set options.
+ options_ = std::move(options);
+
+  // Perform pre-initialization actions; by default, this sets the process
+  // engine for image pre-processing to kLibyuv.
+ RETURN_IF_ERROR(PreInit());
+
+ // Sanity check and set inputs and outputs.
+ RETURN_IF_ERROR(CheckAndSetInputs());
+ RETURN_IF_ERROR(CheckAndSetOutputs());
+
+ // Initialize class whitelisting/blacklisting, if any.
+ RETURN_IF_ERROR(CheckAndSetClassIndexSet());
+
+ return absl::OkStatus();
+}
+
+absl::Status ObjectDetector::PreInit() {
+ SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
+ return absl::OkStatus();
+}
+
+absl::Status ObjectDetector::CheckAndSetOutputs() {
+ // First, sanity checks on the model itself.
+ const TfLiteEngine::Interpreter* interpreter = engine_->interpreter();
+ // Check the number of output tensors.
+ if (TfLiteEngine::OutputCount(interpreter) != 4) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Mobile SSD models are expected to have exactly 4 "
+ "outputs, found %d",
+ TfLiteEngine::OutputCount(interpreter)),
+ TfLiteSupportStatus::kInvalidNumOutputTensorsError);
+ }
+ // Check tensor dimensions and batch size.
+ for (int i = 0; i < 4; ++i) {
+ const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, i);
+ if (tensor->dims->size != kOutputTensorsExpectedDims[i]) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Output tensor at index %d is expected to "
+ "have %d dimensions, found %d.",
+ i, kOutputTensorsExpectedDims[i], tensor->dims->size),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ if (tensor->dims->data[0] != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Expected batch size of 1, found %d.",
+ tensor->dims->data[0]),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ }
+
+ // Now, perform sanity checks and extract metadata.
+ const ModelMetadataExtractor* metadata_extractor =
+ engine_->metadata_extractor();
+ // Check that metadata is available.
+ if (metadata_extractor->GetModelMetadata() == nullptr ||
+ metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) {
+ return CreateStatusWithPayload(StatusCode::kInvalidArgument,
+ "Object detection models require TFLite "
+ "Model Metadata but none was found",
+ TfLiteSupportStatus::kMetadataNotFoundError);
+ }
+ // Check output tensor metadata is present and consistent with model.
+ auto output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata();
+ if (output_tensors_metadata == nullptr ||
+ output_tensors_metadata->size() != 4) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Mismatch between number of output tensors (4) and output tensors "
+ "metadata (%d).",
+ output_tensors_metadata == nullptr
+ ? 0
+ : output_tensors_metadata->size()),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+
+ // Extract mandatory BoundingBoxProperties for easier access at
+ // post-processing time, performing sanity checks on the fly.
+ ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
+ GetBoundingBoxProperties(*output_tensors_metadata->Get(0)));
+ if (bounding_box_properties->index() == nullptr) {
+ bounding_box_corners_order_ = {0, 1, 2, 3};
+ } else {
+ auto bounding_box_index = bounding_box_properties->index();
+ bounding_box_corners_order_ = {
+ bounding_box_index->Get(0),
+ bounding_box_index->Get(1),
+ bounding_box_index->Get(2),
+ bounding_box_index->Get(3),
+ };
+ }
+
+ // Build label map (if available) from metadata.
+ ASSIGN_OR_RETURN(
+ label_map_,
+ GetLabelMapIfAny(*metadata_extractor, *output_tensors_metadata->Get(1),
+ options_->display_names_locale()));
+
+ // Set score threshold.
+ if (options_->has_score_threshold()) {
+ score_threshold_ = options_->score_threshold();
+ } else {
+ ASSIGN_OR_RETURN(score_threshold_,
+ GetScoreThreshold(*metadata_extractor,
+ *output_tensors_metadata->Get(2)));
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status ObjectDetector::CheckAndSetClassIndexSet() {
+ // Exit early if no blacklist/whitelist.
+ if (options_->class_name_blacklist_size() == 0 &&
+ options_->class_name_whitelist_size() == 0) {
+ return absl::OkStatus();
+ }
+ // Label map is mandatory.
+ if (label_map_.empty()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Using `class_name_whitelist` or `class_name_blacklist` requires "
+ "labels to be present in the TFLite Model Metadata but none was found.",
+ TfLiteSupportStatus::kMetadataMissingLabelsError);
+ }
+
+ class_index_set_.is_whitelist = options_->class_name_whitelist_size() > 0;
+ const auto& class_names = class_index_set_.is_whitelist
+ ? options_->class_name_whitelist()
+ : options_->class_name_blacklist();
+ class_index_set_.values.clear();
+ for (const auto& class_name : class_names) {
+ int index = -1;
+ for (int i = 0; i < label_map_.size(); ++i) {
+ if (label_map_[i].name == class_name) {
+ index = i;
+ break;
+ }
+ }
+ // Ignore duplicate or unknown classes.
+ if (index < 0 || class_index_set_.values.contains(index)) {
+ continue;
+ }
+ class_index_set_.values.insert(index);
+ }
+
+ if (class_index_set_.values.empty()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Invalid class names specified via `class_name_%s`: none match "
+ "with model labels.",
+ class_index_set_.is_whitelist ? "whitelist" : "blacklist"),
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+
+ return absl::OkStatus();
+}
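+
+// For illustration, restricting detection to two classes is configured as
+// follows (a sketch; the names must match entries of the model's label map):
+//
+//   ObjectDetectorOptions options;
+//   options.add_class_name_whitelist("cat");
+//   options.add_class_name_whitelist("dog");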
+
+StatusOr<DetectionResult> ObjectDetector::Detect(
+ const FrameBuffer& frame_buffer) {
+ BoundingBox roi;
+ roi.set_width(frame_buffer.dimension().width);
+ roi.set_height(frame_buffer.dimension().height);
+  // Rely on `Infer` instead of `InferWithFallback` as the DetectionPostProcess
+  // op doesn't support hardware acceleration at this time.
+ return Infer(frame_buffer, roi);
+}
+
+StatusOr<DetectionResult> ObjectDetector::Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) {
+  // Most of the checks here should never fail, as outputs have been validated
+ // at construction time. Checking nonetheless and returning internal errors if
+ // something bad happens.
+ RETURN_IF_ERROR(SanityCheckOutputTensors(output_tensors));
+
+ // Get number of available results.
+ const int num_results =
+ static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]);
+ // Compute number of max results to return.
+ const int max_results = options_->max_results() > 0
+ ? std::min(options_->max_results(), num_results)
+ : num_results;
+ // The dimensions of the upright (i.e. rotated according to its orientation)
+ // input frame.
+ FrameBuffer::Dimension upright_input_frame_dimensions =
+ frame_buffer.dimension();
+ if (RequireDimensionSwap(frame_buffer.orientation(),
+ FrameBuffer::Orientation::kTopLeft)) {
+ upright_input_frame_dimensions.Swap();
+ }
+
+ const float* locations = AssertAndReturnTypedTensor<float>(output_tensors[0]);
+ const float* classes = AssertAndReturnTypedTensor<float>(output_tensors[1]);
+ const float* scores = AssertAndReturnTypedTensor<float>(output_tensors[2]);
+ DetectionResult results;
+ for (int i = 0; i < num_results; ++i) {
+ const int class_index = static_cast<int>(classes[i]);
+ const float score = scores[i];
+ if (!IsClassIndexAllowed(class_index) || score < score_threshold_) {
+ continue;
+ }
+ Detection* detection = results.add_detections();
+    // Denormalize the bounding box coordinates in the upright frame coordinate
+    // system, then rotate back from frame_buffer.orientation() to the
+    // unrotated frame of reference coordinate system (i.e. with
+    // orientation = kTopLeft).
+ *detection->mutable_bounding_box() = OrientAndDenormalizeBoundingBox(
+ /*from_left=*/locations[4 * i + bounding_box_corners_order_[0]],
+ /*from_top=*/locations[4 * i + bounding_box_corners_order_[1]],
+ /*from_right=*/locations[4 * i + bounding_box_corners_order_[2]],
+ /*from_bottom=*/locations[4 * i + bounding_box_corners_order_[3]],
+ /*from_orientation=*/frame_buffer.orientation(),
+ /*to_orientation=*/FrameBuffer::Orientation::kTopLeft,
+ /*from_dimension=*/upright_input_frame_dimensions);
+ Class* detection_class = detection->add_classes();
+ detection_class->set_index(class_index);
+ detection_class->set_score(score);
+ if (results.detections_size() == max_results) {
+ break;
+ }
+ }
+
+ if (!label_map_.empty()) {
+ RETURN_IF_ERROR(FillResultsFromLabelMap(&results));
+ }
+
+ return results;
+}
+
+bool ObjectDetector::IsClassIndexAllowed(int class_index) {
+ if (class_index_set_.values.empty()) {
+ return true;
+ }
+ if (class_index_set_.is_whitelist) {
+ return class_index_set_.values.contains(class_index);
+ } else {
+ return !class_index_set_.values.contains(class_index);
+ }
+}
+
+absl::Status ObjectDetector::FillResultsFromLabelMap(DetectionResult* result) {
+ for (int i = 0; i < result->detections_size(); ++i) {
+ Detection* detection = result->mutable_detections(i);
+ for (int j = 0; j < detection->classes_size(); ++j) {
+ Class* detection_class = detection->mutable_classes(j);
+ const int index = detection_class->index();
+ if (index >= label_map_.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Label map does not contain enough elements: model returned "
+ "class index %d but label map only contains %d elements.",
+ index, label_map_.size()),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+ std::string name = label_map_[index].name;
+ if (!name.empty()) {
+ detection_class->set_class_name(name);
+ }
+ std::string display_name = label_map_[index].display_name;
+ if (!display_name.empty()) {
+ detection_class->set_display_name(display_name);
+ }
+ }
+ }
+ return absl::OkStatus();
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/object_detector.h b/tensorflow_lite_support/cc/task/vision/object_detector.h
new file mode 100644
index 00000000..2bd220b3
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/object_detector.h
@@ -0,0 +1,186 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_OBJECT_DETECTOR_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_OBJECT_DETECTOR_H_
+
+#include <memory>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Performs object detection on images.
+//
+// The API expects a TFLite model with mandatory TFLite Model Metadata.
+//
+// Input tensor:
+// (kTfLiteUInt8/kTfLiteFloat32)
+// - image input of size `[batch x height x width x channels]`.
+// - batch inference is not supported (`batch` is required to be 1).
+// - only RGB inputs are supported (`channels` is required to be 3).
+// - if type is kTfLiteFloat32, NormalizationOptions are required to be
+// attached to the metadata for input normalization.
+// Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e.:
+// (kTfLiteFloat32)
+// - locations tensor of size `[num_results x 4]`, the inner array
+//    representing bounding boxes in the form [top, left, bottom, right].
+// - BoundingBoxProperties are required to be attached to the metadata
+// and must specify type=BOUNDARIES and coordinate_type=RATIO.
+// (kTfLiteFloat32)
+// - classes tensor of size `[num_results]`, each value representing the
+// integer index of a class.
+// - optional (but recommended) label map(s) can be attached as
+// AssociatedFile-s with type TENSOR_VALUE_LABELS, containing one label per
+// line. The first such AssociatedFile (if any) is used to fill the
+// `class_name` field of the results. The `display_name` field is filled
+// from the AssociatedFile (if any) whose locale matches the
+// `display_names_locale` field of the `ObjectDetectorOptions` used at
+// creation time ("en" by default, i.e. English). If none of these are
+// available, only the `index` field of the results will be filled.
+// (kTfLiteFloat32)
+// - scores tensor of size `[num_results]`, each value representing the score
+// of the detected object.
+// (kTfLiteFloat32)
+//  - integer num_results as a tensor of size `[1]`.
+//
+// An example of such a model can be found at:
+// https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1
+//
+// A CLI demo tool is available for easily trying out this API, and provides
+// example usage. See:
+// examples/task/vision/desktop/object_detector_demo.cc
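+//
+// Usage example (a minimal sketch, to be read inside a function returning
+// absl::Status; the model path is assumed to be passed via the
+// `ExternalFile.file_name` field, and `BuildRgbFrameBuffer` is a hypothetical
+// helper wrapping raw RGB pixel data):
+//
+//   ObjectDetectorOptions options;
+//   options.mutable_model_file_with_metadata()->set_file_name(
+//       "/path/to/model.tflite");
+//   options.set_max_results(5);
+//   ASSIGN_OR_RETURN(std::unique_ptr<ObjectDetector> detector,
+//                    ObjectDetector::CreateFromOptions(options));
+//   std::unique_ptr<FrameBuffer> frame_buffer =
+//       BuildRgbFrameBuffer(rgb_data, /*width=*/640, /*height=*/480);
+//   ASSIGN_OR_RETURN(DetectionResult result,
+//                    detector->Detect(*frame_buffer));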
+class ObjectDetector : public BaseVisionTaskApi<DetectionResult> {
+ public:
+ using BaseVisionTaskApi::BaseVisionTaskApi;
+
+ // Creates an ObjectDetector from the provided options. A non-default
+ // OpResolver can be specified in order to support custom Ops or specify a
+ // subset of built-in Ops.
+ static tflite::support::StatusOr<std::unique_ptr<ObjectDetector>>
+ CreateFromOptions(
+ const ObjectDetectorOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+
+ // Performs actual detection on the provided FrameBuffer.
+ //
+ // The FrameBuffer can be of any size and any of the supported formats, i.e.
+ // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed
+ // before inference in order to (and in this order):
+ // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to
+ // the dimensions of the model input tensor,
+ // - convert it to the colorspace of the input tensor (i.e. RGB, which is the
+ // only supported colorspace for now),
+ // - rotate it according to its `Orientation` so that inference is performed
+ // on an "upright" image.
+ //
+ // IMPORTANT: the returned bounding boxes are expressed in the unrotated input
+ // frame of reference coordinates system, i.e. in `[0, frame_buffer.width) x
+ // [0, frame_buffer.height)`, which are the dimensions of the underlying
+ // `frame_buffer` data before any `Orientation` flag gets applied.
+ //
+ // In particular, this implies that the returned bounding boxes may not be
+ // directly suitable for display if the input image is displayed *with* the
+ // `Orientation` flag taken into account according to the EXIF specification
+ // (http://jpegclub.org/exif_orientation.html): it may first need to be
+ // rotated. This is typically true when consuming camera frames on Android or
+ // iOS.
+ //
+ // For example, if the input `frame_buffer` has its `Orientation` flag set to
+ // `kLeftBottom` (i.e. the image will be rotated 90° clockwise during
+ // preprocessing to make it "upright"), then the same 90° clockwise rotation
+ // needs to be applied to the bounding box for display.
+ tflite::support::StatusOr<DetectionResult> Detect(
+ const FrameBuffer& frame_buffer);
+
+ protected:
+ // Post-processing to transform the raw model outputs into detection results.
+ tflite::support::StatusOr<DetectionResult> Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
+
+ // Performs sanity checks on the provided ObjectDetectorOptions.
+ static absl::Status SanityCheckOptions(const ObjectDetectorOptions& options);
+
+ // Initializes the ObjectDetector from the provided ObjectDetectorOptions,
+ // whose ownership is transferred to this object.
+ absl::Status Init(std::unique_ptr<ObjectDetectorOptions>);
+
+ // Performs pre-initialization actions.
+ virtual absl::Status PreInit();
+
+ private:
+ // Performs sanity checks on the model outputs and extracts their metadata.
+ absl::Status CheckAndSetOutputs();
+
+ // Performs sanity checks on the class whitelist/blacklist and forms the class
+ // index set.
+ absl::Status CheckAndSetClassIndexSet();
+
+ // Checks if the class at the provided index is allowed, i.e. whitelisted in
+ // case a whitelist is provided or not blacklisted if a blacklist is provided.
+  // Always returns true if no whitelist or blacklist was provided.
+ bool IsClassIndexAllowed(int class_index);
+
+ // Given a DetectionResult object containing class indices, fills the name and
+ // display name from the label map.
+ absl::Status FillResultsFromLabelMap(DetectionResult* result);
+
+ // The options used to build this ObjectDetector.
+ std::unique_ptr<ObjectDetectorOptions> options_;
+
+ // This is populated by reading the label files from the TFLite Model
+ // Metadata: if no such files are available, this is left empty and the
+ // ObjectDetector will only be able to populate the `index` field of the
+  // detection results' `classes` field.
+ std::vector<LabelMapItem> label_map_;
+
+ // For each pack of 4 coordinates returned by the model, this denotes the
+ // order in which to get the left, top, right and bottom coordinates.
+ std::vector<unsigned int> bounding_box_corners_order_;
+
+ // Set of whitelisted or blacklisted class indices.
+ struct ClassIndexSet {
+ absl::flat_hash_set<int> values;
+ bool is_whitelist;
+ };
+ // Whitelisted or blacklisted class indices based on provided options at
+ // construction time. These are used to filter out results during
+ // post-processing.
+ ClassIndexSet class_index_set_;
+
+ // Score threshold. Detections with a confidence below this value are
+  // discarded. If none is provided via metadata or options, -FLT_MAX is set as
+  // the default value.
+ float score_threshold_;
+};
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_OBJECT_DETECTOR_H_
diff --git a/tensorflow_lite_support/cc/task/vision/proto/BUILD b/tensorflow_lite_support/cc/task/vision/proto/BUILD
new file mode 100644
index 00000000..e294da76
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/BUILD
@@ -0,0 +1,208 @@
+load("//tensorflow_lite_support/cc/port:build_defs.bzl", "support_cc_proto_library")
+
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Common vision protos.
+
+proto_library(
+ name = "bounding_box_proto",
+ srcs = ["bounding_box.proto"],
+)
+
+support_cc_proto_library(
+ name = "bounding_box_cc_proto",
+ srcs = ["bounding_box.proto"],
+ deps = [
+ ":bounding_box_proto",
+ ],
+)
+
+cc_library(
+ name = "bounding_box_proto_inc",
+ hdrs = ["bounding_box_proto_inc.h"],
+ deps = [":bounding_box_cc_proto"],
+)
+
+proto_library(
+ name = "class_proto",
+ srcs = ["class.proto"],
+)
+
+support_cc_proto_library(
+ name = "class_cc_proto",
+ srcs = ["class.proto"],
+ deps = [
+ ":class_proto",
+ ],
+)
+
+cc_library(
+ name = "class_proto_inc",
+ hdrs = ["class_proto_inc.h"],
+ deps = [":class_cc_proto"],
+)
+
+# ObjectDetector protos.
+
+proto_library(
+ name = "object_detector_options_proto",
+ srcs = ["object_detector_options.proto"],
+ deps = [
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto",
+ ],
+)
+
+support_cc_proto_library(
+ name = "object_detector_options_cc_proto",
+ srcs = ["object_detector_options.proto"],
+ cc_deps = ["//tensorflow_lite_support/cc/task/core/proto:external_file_cc_proto"],
+ deps = [
+ ":object_detector_options_proto",
+ ],
+)
+
+cc_library(
+ name = "object_detector_options_proto_inc",
+ hdrs = ["object_detector_options_proto_inc.h"],
+ deps = [
+ ":object_detector_options_cc_proto",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ ],
+)
+
+proto_library(
+ name = "detections_proto",
+ srcs = ["detections.proto"],
+ deps = [
+ ":bounding_box_proto",
+ ":class_proto",
+ ],
+)
+
+support_cc_proto_library(
+ name = "detections_cc_proto",
+ srcs = ["detections.proto"],
+ cc_deps = [
+ ":bounding_box_cc_proto",
+ ":class_cc_proto",
+ ],
+ deps = [
+ ":detections_proto",
+ ],
+)
+
+cc_library(
+ name = "detections_proto_inc",
+ hdrs = ["detections_proto_inc.h"],
+ deps = [
+ ":bounding_box_proto_inc",
+ ":class_proto_inc",
+ ":detections_cc_proto",
+ ],
+)
+
+# ImageClassifier protos.
+
+proto_library(
+ name = "image_classifier_options_proto",
+ srcs = ["image_classifier_options.proto"],
+ deps = [
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto",
+ ],
+)
+
+support_cc_proto_library(
+ name = "image_classifier_options_cc_proto",
+ srcs = ["image_classifier_options.proto"],
+ cc_deps = ["//tensorflow_lite_support/cc/task/core/proto:external_file_cc_proto"],
+ deps = [
+ ":image_classifier_options_proto",
+ ],
+)
+
+cc_library(
+ name = "image_classifier_options_proto_inc",
+ hdrs = ["image_classifier_options_proto_inc.h"],
+ deps = [
+ ":image_classifier_options_cc_proto",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ ],
+)
+
+proto_library(
+ name = "classifications_proto",
+ srcs = ["classifications.proto"],
+ deps = [
+ ":class_proto",
+ ],
+)
+
+support_cc_proto_library(
+ name = "classifications_cc_proto",
+ srcs = ["classifications.proto"],
+ cc_deps = [":class_cc_proto"],
+ deps = [
+ ":classifications_proto",
+ ],
+)
+
+cc_library(
+ name = "classifications_proto_inc",
+ hdrs = ["classifications_proto_inc.h"],
+ deps = [
+ ":class_proto_inc",
+ ":classifications_cc_proto",
+ ],
+)
+
+# ImageSegmenter protos.
+
+proto_library(
+ name = "image_segmenter_options_proto",
+ srcs = ["image_segmenter_options.proto"],
+ deps = [
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto",
+ ],
+)
+
+support_cc_proto_library(
+ name = "image_segmenter_options_cc_proto",
+ srcs = ["image_segmenter_options.proto"],
+ cc_deps = ["//tensorflow_lite_support/cc/task/core/proto:external_file_cc_proto"],
+ deps = [
+ ":image_segmenter_options_proto",
+ ],
+)
+
+cc_library(
+ name = "image_segmenter_options_proto_inc",
+ hdrs = ["image_segmenter_options_proto_inc.h"],
+ deps = [
+ ":image_segmenter_options_cc_proto",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ ],
+)
+
+proto_library(
+ name = "segmentations_proto",
+ srcs = ["segmentations.proto"],
+)
+
+support_cc_proto_library(
+ name = "segmentations_cc_proto",
+ srcs = ["segmentations.proto"],
+ deps = [
+ ":segmentations_proto",
+ ],
+)
+
+cc_library(
+ name = "segmentations_proto_inc",
+ hdrs = ["segmentations_proto_inc.h"],
+ deps = [":segmentations_cc_proto"],
+)
diff --git a/tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto b/tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto
new file mode 100644
index 00000000..4c2e1302
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package tflite.task.vision;
+
+// An integer bounding box, axis aligned.
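+// For example, a 100x50 pixel box whose top-left corner sits at pixel
+// (10, 20) corresponds to origin_x=10, origin_y=20, width=100, height=50.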
+message BoundingBox {
+ // The X coordinate of the top-left corner, in pixels.
+ optional int32 origin_x = 1;
+ // The Y coordinate of the top-left corner, in pixels.
+ optional int32 origin_y = 2;
+ // The width of the bounding box, in pixels.
+ optional int32 width = 3;
+ // The height of the bounding box, in pixels.
+ optional int32 height = 4;
+}
diff --git a/tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h
new file mode 100644
index 00000000..ef84b156
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h
@@ -0,0 +1,19 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_BOUNDING_BOX_PROTO_INC_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_BOUNDING_BOX_PROTO_INC_H_
+
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box.pb.h"
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_BOUNDING_BOX_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/vision/proto/class.proto b/tensorflow_lite_support/cc/task/vision/proto/class.proto
new file mode 100644
index 00000000..19e8ac1d
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/class.proto
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package tflite.task.vision;
+
+// A single classification result.
+message Class {
+ // The index of the class in the corresponding label map, usually packed in
+ // the TFLite Model Metadata [1].
+ //
+ // [1]: https://www.tensorflow.org/lite/convert/metadata
+ optional int32 index = 1;
+  // The score for this class, e.g. (but not necessarily) a probability in
+  // [0,1].
+ optional float score = 2;
+  // A human-readable name of the class, filled from the label map.
+ optional string display_name = 3;
+ // An ID for the class, not necessarily human-readable (e.g. a Google
+ // Knowledge Graph ID [1]), filled from the label map.
+ //
+ // [1]: https://developers.google.com/knowledge-graph
+ optional string class_name = 4;
+}
diff --git a/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h
new file mode 100644
index 00000000..2f9a409d
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h
@@ -0,0 +1,20 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_
+#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_
+
+#include "tensorflow_lite_support/cc/task/vision/proto/class.pb.h"
+#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/vision/proto/classifications.proto b/tensorflow_lite_support/cc/task/vision/proto/classifications.proto
new file mode 100644
index 00000000..d3d9c66c
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/classifications.proto
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package tflite.task.vision;
+
+import "tensorflow_lite_support/cc/task/vision/proto/class.proto";
+
+// List of predicted classes (aka labels) for a given image classifier head.
+message Classifications {
+ // The array of predicted classes, usually sorted by descending scores (e.g.
+ // from high to low probability).
+ repeated Class classes = 1;
+ // The index of the image classifier head these classes refer to. This is
+ // useful for multi-head models.
+ optional int32 head_index = 2;
+}
+
+// Contains one set of results per image classifier head.
+message ClassificationResult {
+ repeated Classifications classifications = 1;
+}
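Since `Classifications.classes` is usually sorted by descending score, getting the
top class per head is a simple loop. A hedged sketch, assuming only the
proto-generated accessors:

    #include <cstdio>

    #include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"

    void PrintTopClassPerHead(
        const tflite::task::vision::ClassificationResult& result) {
      for (const auto& head : result.classifications()) {
        if (head.classes().empty()) continue;
        const auto& top = head.classes(0);  // highest-scored class
        std::printf("head %d: index=%d score=%.3f\n", head.head_index(),
                    top.index(), top.score());
      }
    }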
diff --git a/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h
new file mode 100644
index 00000000..62a5f117
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h
@@ -0,0 +1,22 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_
+#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_
+
+#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
+
+#include "tensorflow_lite_support/cc/task/vision/proto/classifications.pb.h"
+#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/vision/proto/detections.proto b/tensorflow_lite_support/cc/task/vision/proto/detections.proto
new file mode 100644
index 00000000..b600fc93
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/detections.proto
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package tflite.task.vision;
+
+import "tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto";
+import "tensorflow_lite_support/cc/task/vision/proto/class.proto";
+
+// A single detected object.
+message Detection {
+ // The bounding box.
+ //
+ // IMPORTANT: when using the Task APIs, the bounding box is expressed in the
+  // coordinate system of the unrotated input frame of reference, i.e. in
+  // `[0, frame_buffer.width) x [0, frame_buffer.height)`, which are the
+  // dimensions of the underlying `frame_buffer` data before any `Orientation`
+  // flag gets applied.
+ //
+ // In particular, this implies that the returned bounding boxes may not be
+ // directly suitable for display if the input image is displayed *with* the
+ // `Orientation` flag taken into account according to the EXIF specification
+ // (http://jpegclub.org/exif_orientation.html): it may first need to be
+ // rotated.
+ //
+ // For example, if the input `frame_buffer` has its `Orientation` flag set to
+ // `kLeftBottom` (i.e. the image will be rotated 90° clockwise during
+ // preprocessing to make it "upright"), then the same 90° clockwise rotation
+ // needs to be applied to the bounding box for display.
+ optional BoundingBox bounding_box = 2;
+ // The candidate classes, sorted by descending score.
+ repeated Class classes = 3;
+ // Reserved tags.
+ reserved 1, 4;
+}
+
+// List of detected objects.
+message DetectionResult {
+ repeated Detection detections = 1;
+}
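To make the rotation caveat above concrete, the kLeftBottom example boils down to
the following coordinate mapping. This is a sketch under the stated assumptions
(a 90° clockwise rotation of an unrotated frame of the given `height`), not part
of the library:

    #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"

    // Maps `box` from the unrotated frame into the frame obtained by rotating
    // the image 90° clockwise: the new origin is the former bottom-left corner.
    tflite::task::vision::BoundingBox RotateBox90CW(
        const tflite::task::vision::BoundingBox& box, int height) {
      tflite::task::vision::BoundingBox rotated;
      rotated.set_origin_x(height - (box.origin_y() + box.height()));
      rotated.set_origin_y(box.origin_x());
      rotated.set_width(box.height());
      rotated.set_height(box.width());
      return rotated;
    }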
diff --git a/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h
new file mode 100644
index 00000000..2b63cad6
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h
@@ -0,0 +1,23 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_
+#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_
+
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
+
+#include "tensorflow_lite_support/cc/task/vision/proto/detections.pb.h"
+#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto b/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto
new file mode 100644
index 00000000..24cd85f3
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package tflite.task.vision;
+
+import "tensorflow_lite_support/cc/task/core/proto/external_file.proto";
+
+// Options for setting up an ImageClassifier.
+// Next Id: 14
+message ImageClassifierOptions {
+ // The external model file, as a single standalone TFLite file. If it is
+ // packed with TFLite Model Metadata [1], those are used to populate e.g. the
+ // label map, score calibration and recommended score thresholds. Models
+ // without any such metadata or partial metadata are supported, but may result
+ // in the image classifier providing degraded functionality; typically, a
+ // model that doesn't contain any label map won't be able to return any class
+ // or display names but will be limited to returning class indices.
+ //
+ // [1]: https://www.tensorflow.org/lite/convert/metadata
+ optional core.ExternalFile model_file_with_metadata = 10;
+
+ // The locale to use for display names specified through the TFLite Model
+ // Metadata, if any. Defaults to English.
+ optional string display_names_locale = 11 [default = "en"];
+
+ // The maximum number of top-scored classification results to return. If < 0,
+ // all available results will be returned. If 0, an invalid argument error is
+ // returned.
+ optional int32 max_results = 2 [default = -1];
+
+ // Score threshold in [0,1), overrides the ones provided in the model metadata
+ // (if any). Results below this value are rejected.
+ optional float score_threshold = 3;
+
+ // Optional whitelist of class names. If non-empty, classifications whose
+ // class name is not in this set will be filtered out. Duplicate or unknown
+ // class names are ignored. Mutually exclusive with class_name_blacklist.
+ repeated string class_name_whitelist = 4;
+
+ // Optional blacklist of class names. If non-empty, classifications whose
+ // class name is in this set will be filtered out. Duplicate or unknown
+ // class names are ignored. Mutually exclusive with class_name_whitelist.
+ repeated string class_name_blacklist = 5;
+
+  // The number of threads to be used for TFLite ops that support
+  // multi-threading when running inference on CPU. num_threads should be
+  // greater than 0 or equal to -1; setting num_threads to -1 lets the TFLite
+  // runtime set the value.
+ optional int32 num_threads = 13 [default = -1];
+
+ // Reserved tags.
+ reserved 1, 6, 7, 8, 9, 12;
+}
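Setting these options from C++ follows the usual proto2 setter pattern. A minimal
sketch; the model path is a placeholder, and it assumes `ExternalFile` exposes a
`file_name` field:

    #include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"

    tflite::task::vision::ImageClassifierOptions MakeOptions() {
      tflite::task::vision::ImageClassifierOptions options;
      // Hypothetical model path, for illustration only.
      options.mutable_model_file_with_metadata()->set_file_name(
          "/path/to/model.tflite");
      options.set_max_results(3);         // keep only the 3 top-scored classes
      options.set_score_threshold(0.5f);  // overrides metadata thresholds, if any
      return options;
    }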
diff --git a/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h
new file mode 100644
index 00000000..03dcd759
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h
@@ -0,0 +1,22 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_
+#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_
+
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+
+#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.pb.h"
+#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto b/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto
new file mode 100644
index 00000000..3afed86a
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto
@@ -0,0 +1,61 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package tflite.task.vision;
+
+import "tensorflow_lite_support/cc/task/core/proto/external_file.proto";
+
+// Options for setting up an ImageSegmenter.
+// Next Id: 8
+message ImageSegmenterOptions {
+ // The external model file, as a single standalone TFLite file. If it is
+  // packed with TFLite Model Metadata [1], those are used to populate the label
+ // map. Models without any such metadata or partial metadata are supported,
+ // but may result in the segmenter providing degraded functionality;
+ // typically, a model that doesn't contain any label map won't be able to
+ // return any class or display names.
+ //
+ // [1]: https://www.tensorflow.org/lite/convert/metadata
+ optional core.ExternalFile model_file_with_metadata = 5;
+
+ // The locale to use for display names specified through the TFLite Model
+ // Metadata, if any. Defaults to English.
+ optional string display_names_locale = 6 [default = "en"];
+
+ // Output mask type. This allows specifying the type of post-processing to
+ // perform on the raw model results (see SegmentationResult proto for more).
+ enum OutputType {
+ UNSPECIFIED = 0;
+ // Gives a single output mask where each pixel represents the class which
+ // the pixel in the original image was predicted to belong to.
+ CATEGORY_MASK = 1;
+ // Gives a list of output masks where, for each mask, each pixel represents
+ // the prediction confidence, usually in the [0, 1] range.
+ CONFIDENCE_MASK = 2;
+ }
+ // Optional output mask type.
+ optional OutputType output_type = 3 [default = CATEGORY_MASK];
+
+  // The number of threads to be used for TFLite ops that support
+  // multi-threading when running inference on CPU. num_threads should be
+  // greater than 0 or equal to -1; setting num_threads to -1 lets the TFLite
+  // runtime set the value.
+ optional int32 num_threads = 7 [default = -1];
+
+ // Reserved tags.
+ reserved 1, 2, 4;
+}
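Requesting confidence masks instead of the default category mask is a one-line
change; a short sketch, again assuming the proto-generated accessors:

    #include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"

    tflite::task::vision::ImageSegmenterOptions MakeSegmenterOptions() {
      tflite::task::vision::ImageSegmenterOptions options;
      options.set_output_type(
          tflite::task::vision::ImageSegmenterOptions::CONFIDENCE_MASK);
      return options;
    }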
diff --git a/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h
new file mode 100644
index 00000000..aaaecf36
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h
@@ -0,0 +1,22 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_
+#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_
+
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+
+#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.pb.h"
+#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto b/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto
new file mode 100644
index 00000000..b55e9740
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto
@@ -0,0 +1,62 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package tflite.task.vision;
+
+import "tensorflow_lite_support/cc/task/core/proto/external_file.proto";
+
+// Options for setting up an ObjectDetector.
+// Next Id: 8.
+message ObjectDetectorOptions {
+ // The external model file, as a single standalone TFLite file packed with
+  // TFLite Model Metadata [1]. The metadata is mandatory, and is used to
+  // populate e.g. the label map and the recommended score threshold.
+ //
+ // [1]: https://www.tensorflow.org/lite/convert/metadata
+ optional core.ExternalFile model_file_with_metadata = 1;
+
+ // The locale to use for display names specified through the TFLite Model
+ // Metadata, if any. Defaults to English.
+ optional string display_names_locale = 2 [default = "en"];
+
+ // The maximum number of top-scored detection results to return. If < 0, all
+ // available results will be returned. If 0, an invalid argument error is
+ // returned. Note that models may intrinsically be limited to returning a
+ // maximum number of results N: if the provided value here is above N, only N
+ // results will be returned.
+ optional int32 max_results = 3 [default = -1];
+
+ // Score threshold to override the one provided in the model metadata (if
+ // any). Detection results with a score below this value are rejected.
+ optional float score_threshold = 4;
+
+ // Optional whitelist of class names. If non-empty, detection results whose
+ // class name is not in this set will be filtered out. Duplicate or unknown
+ // class names are ignored. Mutually exclusive with class_name_blacklist.
+ repeated string class_name_whitelist = 5;
+
+ // Optional blacklist of class names. If non-empty, detection results whose
+ // class name is in this set will be filtered out. Duplicate or unknown
+ // class names are ignored. Mutually exclusive with class_name_whitelist.
+ repeated string class_name_blacklist = 6;
+
+  // The number of threads to be used for TFLite ops that support
+  // multi-threading when running inference on CPU. num_threads should be
+  // greater than 0 or equal to -1; setting num_threads to -1 lets the TFLite
+  // runtime set the value.
+ optional int32 num_threads = 7 [default = -1];
+}
diff --git a/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h
new file mode 100644
index 00000000..27898470
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h
@@ -0,0 +1,22 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_
+#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_
+
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+
+#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options.pb.h"
+#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto b/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto
new file mode 100644
index 00000000..259bee81
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto
@@ -0,0 +1,109 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package tflite.task.vision;
+
+// Results of performing image segmentation.
+// Note that, at this time, a single `Segmentation` element is expected to be
+// returned; the field is made repeated for later extension to e.g. instance
+// segmentation models, which may return one segmentation per object.
+message SegmentationResult {
+ repeated Segmentation segmentation = 1;
+}
+
+// Next Id: 6
+message Segmentation {
+  // Confidence mask. This is a flattened 2D-array in row major order. For each
+  // pixel, the value indicates the prediction confidence, usually in the
+  // [0, 1] range, where higher values represent a stronger confidence.
+  // Ultimately this is model-specific, and other ranges of values might be
+  // used.
+ message ConfidenceMask {
+ repeated float value = 1 [packed = true];
+ }
+
+ // List of confidence masks with respect to the model output depth (this depth
+ // represents how many classes are supported). Note: some models have a single
+ // class (e.g. a sky segmentation model) which turns into a single confidence
+ // mask in this list.
+ message ConfidenceMasks {
+ repeated ConfidenceMask confidence_mask = 1;
+ }
+
+  // IMPORTANT: segmentation masks are not directly suited for display, in
+ // particular:
+ // * they are relative to the unrotated input frame, i.e. *not* taking into
+ // account the `Orientation` flag of the input FrameBuffer,
+ // * their dimensions are intrinsic to the model, i.e. *not* dependent on the
+ // input FrameBuffer dimensions.
+ //
+ // Example of such post-processing, assuming:
+ // * an input FrameBuffer with width=640, height=480, orientation=kLeftBottom
+ // (i.e. the image will be rotated 90° clockwise during preprocessing to
+ // make it "upright"),
+ // * a model outputting masks of size 224x224.
+ // In order to be directly displayable on top of the input image assumed to
+ // be displayed *with* the `Orientation` flag taken into account (according to
+ // the EXIF specification [1]), the masks need to be:
+ // * re-scaled to 640 x 480,
+ // * then rotated 90° clockwise.
+ //
+ // [1]: http://jpegclub.org/exif_orientation.html
+ oneof mask_oneof {
+ // Category mask. This is a flattened 2D-array of size `width` x `height`,
+ // in row major order. The value of each pixel in this mask represents the
+ // class to which the pixel belongs.
+ // See `colored_labels` for instructions on how to get pixel labels and
+ // display color.
+ bytes category_mask = 1;
+
+    // One confidence mask of size `width` x `height` for each of the supported
+ // classes. The value of each pixel in these masks represents the confidence
+ // score for this particular class.
+ // See `colored_labels` for instructions on how to get pixel labels and
+ // display color.
+ ConfidenceMasks confidence_masks = 4;
+ }
+ // The width of the mask. This is an intrinsic parameter of the model being
+ // used, and does not depend on the input image dimensions.
+ optional int32 width = 2;
+ // The height of the mask. This is an intrinsic parameter of the model being
+ // used, and does not depend on the input image dimensions.
+ optional int32 height = 3;
+
+ // Defines a label associated with an RGB color, for display purposes.
+ message ColoredLabel {
+ // The RGB color components for the label, in the [0, 255] range.
+ optional uint32 r = 1;
+ optional uint32 g = 2;
+ optional uint32 b = 3;
+ // The class name, as provided in the label map packed in the TFLite Model
+ // Metadata.
+ optional string class_name = 4;
+ // The display name, as provided in the label map (if available) packed in
+ // the TFLite Model Metadata. See `display_names_locale` field in
+ // ImageSegmenterOptions.
+ optional string display_name = 5;
+ }
+
+ // The list of colored labels for all the supported categories. Depending on
+ // which is present, this list is in 1:1 correspondence with:
+ // * `category_mask` pixel values, i.e. a pixel with value `i` is
+ // associated with `colored_labels[i]`,
+ // * `confidence_masks` indices, i.e. `confidence_masks[i]` is associated with
+ // `colored_labels[i]`.
+ repeated ColoredLabel colored_labels = 5;
+}
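The 1:1 correspondence between `category_mask` values and `colored_labels` is what
makes a CATEGORY_MASK directly renderable. A hedged sketch expanding it into
interleaved RGB (assumes every pixel value has a matching colored label):

    #include <cstdint>
    #include <string>
    #include <vector>

    #include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"

    // Expands a CATEGORY_MASK into a width * height * 3 interleaved RGB buffer.
    std::vector<uint8_t> Colorize(const tflite::task::vision::Segmentation& seg) {
      const std::string& mask = seg.category_mask();
      std::vector<uint8_t> rgb;
      rgb.reserve(mask.size() * 3);
      for (const unsigned char pixel : mask) {
        // Pixel value `i` maps to `colored_labels[i]`.
        const auto& label = seg.colored_labels(pixel);
        rgb.push_back(static_cast<uint8_t>(label.r()));
        rgb.push_back(static_cast<uint8_t>(label.g()));
        rgb.push_back(static_cast<uint8_t>(label.b()));
      }
      return rgb;
    }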
diff --git a/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h
new file mode 100644
index 00000000..cfc96e69
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h
@@ -0,0 +1,19 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_
+#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_
+
+#include "tensorflow_lite_support/cc/task/vision/proto/segmentations.pb.h"
+#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_
diff --git a/tensorflow_lite_support/cc/task/vision/utils/BUILD b/tensorflow_lite_support/cc/task/vision/utils/BUILD
new file mode 100644
index 00000000..89951451
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/BUILD
@@ -0,0 +1,109 @@
+package(
+ default_visibility = [
+ "//tensorflow_lite_support:users",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "score_calibration",
+ srcs = ["score_calibration.cc"],
+ hdrs = ["score_calibration.h"],
+ deps = [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/vision/core:label_map_item",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
+ name = "frame_buffer_common_utils",
+ srcs = [
+ "frame_buffer_common_utils.cc",
+ ],
+ hdrs = [
+ "frame_buffer_common_utils.h",
+ "frame_buffer_utils_interface.h",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "frame_buffer_utils",
+ srcs = [
+ "frame_buffer_utils.cc",
+ ],
+ hdrs = [
+ "frame_buffer_utils.h",
+ ],
+ deps = [
+ ":frame_buffer_common_utils",
+ ":libyuv_frame_buffer_utils",
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:variant",
+ "@org_tensorflow//tensorflow/lite/kernels:op_macros",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:compatibility",
+ ],
+)
+
+cc_library(
+ name = "libyuv_frame_buffer_utils",
+ srcs = ["libyuv_frame_buffer_utils.cc"],
+ hdrs = ["libyuv_frame_buffer_utils.h"],
+ deps = [
+ ":frame_buffer_common_utils",
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@libyuv",
+ ],
+)
+
+cc_library(
+ name = "image_tensor_specs",
+ srcs = ["image_tensor_specs.cc"],
+ hdrs = ["image_tensor_specs.h"],
+ deps = [
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:tflite_engine",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/types:optional",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ ],
+)
diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
new file mode 100644
index 00000000..fa9b05f5
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
@@ -0,0 +1,428 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+namespace {
+
+using ::tflite::support::StatusOr;
+
+constexpr int kRgbaChannels = 4;
+constexpr int kRgbChannels = 3;
+constexpr int kGrayChannel = 1;
+
+// Creates a FrameBuffer from a raw NV12 buffer and the given arguments.
+std::unique_ptr<FrameBuffer> CreateFromNV12RawBuffer(
+ const uint8* input, FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+ const std::vector<FrameBuffer::Plane> planes_nv12 = {
+ {input, /*stride=*/{dimension.width, kGrayChannel}},
+ {input + dimension.Size(), /*stride=*/{dimension.width, 2}}};
+ return FrameBuffer::Create(planes_nv12, dimension, FrameBuffer::Format::kNV12,
+ orientation, timestamp);
+}
+
+// Creates a FrameBuffer from a raw NV21 buffer and the given arguments.
+std::unique_ptr<FrameBuffer> CreateFromNV21RawBuffer(
+ const uint8* input, FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+ FrameBuffer::Plane input_plane = {/*buffer=*/input,
+ /*stride=*/{dimension.width, kGrayChannel}};
+ return FrameBuffer::Create({input_plane}, dimension,
+ FrameBuffer::Format::kNV21, orientation,
+ timestamp);
+}
+
+// Indicates whether the given buffers have the same dimensions.
+bool AreBufferDimsEqual(const FrameBuffer& buffer1,
+ const FrameBuffer& buffer2) {
+ return buffer1.dimension() == buffer2.dimension();
+}
+
+// Indicates whether the given buffers' formats are compatible. Identical
+// formats are compatible, and all YUV family formats (NV12, NV21, YV12, YV21,
+// etc.) are compatible with each other.
+bool AreBufferFormatsCompatible(const FrameBuffer& buffer1,
+ const FrameBuffer& buffer2) {
+ switch (buffer1.format()) {
+ case FrameBuffer::Format::kRGBA:
+ case FrameBuffer::Format::kRGB:
+ return (buffer2.format() == FrameBuffer::Format::kRGBA ||
+ buffer2.format() == FrameBuffer::Format::kRGB);
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return (buffer2.format() == FrameBuffer::Format::kNV12 ||
+ buffer2.format() == FrameBuffer::Format::kNV21 ||
+ buffer2.format() == FrameBuffer::Format::kYV12 ||
+ buffer2.format() == FrameBuffer::Format::kYV21);
+ case FrameBuffer::Format::kGRAY:
+ default:
+ return buffer1.format() == buffer2.format();
+ }
+}
+
+} // namespace
+
+// Miscellaneous Methods
+// -----------------------------------------------------------------
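+// For example, an NV12 buffer of dimension 640x480 occupies
+// 640*480 (Y plane) + 320*240*2 (interleaved UV plane) = 460800 bytes, while
+// the same dimension in RGB occupies 640*480*3 = 921600 bytes.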
+int GetFrameBufferByteSize(FrameBuffer::Dimension dimension,
+ FrameBuffer::Format format) {
+ switch (format) {
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+      return /*y plane*/ dimension.Size() +
+             /*uv plane*/ ((dimension.width + 1) / 2) *
+                 ((dimension.height + 1) / 2) * 2;
+ case FrameBuffer::Format::kRGB:
+ return dimension.Size() * 3;
+ case FrameBuffer::Format::kRGBA:
+ return dimension.Size() * 4;
+ case FrameBuffer::Format::kGRAY:
+ return dimension.Size();
+ default:
+ return 0;
+ }
+}
+
+StatusOr<int> GetPixelStrides(FrameBuffer::Format format) {
+ switch (format) {
+ case FrameBuffer::Format::kGRAY:
+ return kGrayPixelBytes;
+ case FrameBuffer::Format::kRGB:
+ return kRgbPixelBytes;
+ case FrameBuffer::Format::kRGBA:
+ return kRgbaPixelBytes;
+ default:
+ return absl::InvalidArgumentError(absl::StrFormat(
+ "GetPixelStrides does not support format: %i.", format));
+ }
+}
+
+StatusOr<const uint8*> GetUvRawBuffer(const FrameBuffer& buffer) {
+ if (buffer.format() != FrameBuffer::Format::kNV12 &&
+ buffer.format() != FrameBuffer::Format::kNV21) {
+ return absl::InvalidArgumentError(
+ "Only support getting biplanar UV buffer from NV12/NV21 frame buffer.");
+ }
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ const uint8* uv_buffer = buffer.format() == FrameBuffer::Format::kNV12
+ ? yuv_data.u_buffer
+ : yuv_data.v_buffer;
+ return uv_buffer;
+}
+
+StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
+ FrameBuffer::Dimension dimension, FrameBuffer::Format format) {
+ if (dimension.width <= 0 || dimension.height <= 0) {
+ return absl::InvalidArgumentError(
+ absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width,
+ dimension.height));
+ }
+ switch (format) {
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return FrameBuffer::Dimension{(dimension.width + 1) / 2,
+ (dimension.height + 1) / 2};
+ default:
+ return absl::InvalidArgumentError(
+ absl::StrFormat("Input format is not YUV-like: %i.", format));
+ }
+}
+
+FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1) {
+ return {x1 - x0 + 1, y1 - y0 + 1};
+}
+
+// Validation Methods
+// -----------------------------------------------------------------
+
+absl::Status ValidateBufferPlaneMetadata(const FrameBuffer& buffer) {
+ if (buffer.plane_count() < 1) {
+ return absl::InvalidArgumentError(
+ "There must be at least 1 plane specified.");
+ }
+
+ for (int i = 0; i < buffer.plane_count(); i++) {
+ if (buffer.plane(i).stride.row_stride_bytes == 0 ||
+ buffer.plane(i).stride.pixel_stride_bytes == 0) {
+ return absl::InvalidArgumentError("Invalid stride information.");
+ }
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status ValidateBufferFormat(const FrameBuffer& buffer) {
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kGRAY:
+ case FrameBuffer::Format::kRGB:
+ case FrameBuffer::Format::kRGBA:
+ if (buffer.plane_count() == 1) return absl::OkStatus();
+ return absl::InvalidArgumentError(
+ "Plane count must be 1 for grayscale and RGB[a] buffers.");
+ case FrameBuffer::Format::kNV21:
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kYV21:
+ case FrameBuffer::Format::kYV12:
+ return absl::OkStatus();
+ default:
+ return absl::InternalError(
+ absl::StrFormat("Unsupported buffer format: %i.", buffer.format()));
+ }
+}
+
+absl::Status ValidateBufferFormats(const FrameBuffer& buffer1,
+ const FrameBuffer& buffer2) {
+ RETURN_IF_ERROR(ValidateBufferFormat(buffer1));
+ RETURN_IF_ERROR(ValidateBufferFormat(buffer2));
+ return absl::OkStatus();
+}
+
+absl::Status ValidateResizeBufferInputs(const FrameBuffer& buffer,
+ const FrameBuffer& output_buffer) {
+ bool valid_format = false;
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kGRAY:
+ case FrameBuffer::Format::kRGB:
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ valid_format = (buffer.format() == output_buffer.format());
+ break;
+ case FrameBuffer::Format::kRGBA:
+ valid_format = (output_buffer.format() == FrameBuffer::Format::kRGBA ||
+ output_buffer.format() == FrameBuffer::Format::kRGB);
+ break;
+ default:
+ return absl::InternalError(
+ absl::StrFormat("Unsupported buffer format: %i.", buffer.format()));
+ }
+ if (!valid_format) {
+ return absl::InvalidArgumentError(
+ "Input and output buffer formats must match.");
+ }
+ return ValidateBufferFormats(buffer, output_buffer);
+}
+
+absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
+ const FrameBuffer& output_buffer,
+ int angle_deg) {
+ if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
+ return absl::InvalidArgumentError(
+ "Input and output buffer formats must match.");
+ }
+
+ const bool is_dimension_change = (angle_deg / 90) % 2 == 1;
+ const bool are_dimensions_rotated =
+ (buffer.dimension().width == output_buffer.dimension().height) &&
+ (buffer.dimension().height == output_buffer.dimension().width);
+ const bool are_dimensions_equal =
+ buffer.dimension() == output_buffer.dimension();
+
+ if (angle_deg >= 360 || angle_deg <= 0 || angle_deg % 90 != 0) {
+ return absl::InvalidArgumentError(
+ "Rotation angle must be between 0 and 360, in multiples of 90 "
+ "degrees.");
+ } else if ((is_dimension_change && !are_dimensions_rotated) ||
+ (!is_dimension_change && !are_dimensions_equal)) {
+ return absl::InvalidArgumentError(
+ "Output buffer has invalid dimensions for rotation.");
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
+ const FrameBuffer& output_buffer, int x0,
+ int y0, int x1, int y1) {
+ if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
+ return absl::InvalidArgumentError(
+ "Input and output buffer formats must match.");
+ }
+
+ bool is_buffer_size_valid =
+ ((x1 < buffer.dimension().width) && y1 < buffer.dimension().height);
+ bool are_points_valid = (x0 >= 0) && (y0 >= 0) && (x1 >= x0) && (y1 >= y0);
+
+ if (!is_buffer_size_valid || !are_points_valid) {
+ return absl::InvalidArgumentError("Invalid crop coordinates.");
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer,
+ const FrameBuffer& output_buffer) {
+ if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
+ return absl::InvalidArgumentError(
+ "Input and output buffer formats must match.");
+ }
+ return AreBufferDimsEqual(buffer, output_buffer)
+ ? absl::OkStatus()
+ : absl::InvalidArgumentError(
+ "Input and output buffers must have the same dimensions.");
+}
+
+absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
+ FrameBuffer::Format to_format) {
+ if (from_format == to_format) {
+ return absl::InvalidArgumentError("Formats must be different.");
+ }
+
+ switch (from_format) {
+ case FrameBuffer::Format::kGRAY:
+ return absl::InvalidArgumentError(
+ "Grayscale format does not convert to other formats.");
+ case FrameBuffer::Format::kRGB:
+ if (to_format == FrameBuffer::Format::kRGBA) {
+ return absl::InvalidArgumentError(
+ "RGB format does not convert to RGBA");
+ }
+ return absl::OkStatus();
+ case FrameBuffer::Format::kRGBA:
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return absl::OkStatus();
+ default:
+ return absl::InternalError(
+ absl::StrFormat("Unsupported buffer format: %i.", from_format));
+ }
+}
+
+// Creation Methods
+// -----------------------------------------------------------------
+
+// Creates a FrameBuffer from a raw RGBA buffer and the given arguments.
+std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
+ const uint8* input, FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+ FrameBuffer::Plane input_plane = {
+ /*buffer=*/input,
+ /*stride=*/{dimension.width * kRgbaChannels, kRgbaChannels}};
+ return FrameBuffer::Create({input_plane}, dimension,
+ FrameBuffer::Format::kRGBA, orientation,
+ timestamp);
+}
+
+// Creates a FrameBuffer from a raw RGB buffer and the given arguments.
+std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
+ const uint8* input, FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+ FrameBuffer::Plane input_plane = {
+ /*buffer=*/input,
+ /*stride=*/{dimension.width * kRgbChannels, kRgbChannels}};
+ return FrameBuffer::Create({input_plane}, dimension,
+ FrameBuffer::Format::kRGB, orientation, timestamp);
+}
+
+// Creates a FrameBuffer from a raw grayscale buffer and the given arguments.
+std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
+ const uint8* input, FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+ FrameBuffer::Plane input_plane = {/*buffer=*/input,
+ /*stride=*/{dimension.width, kGrayChannel}};
+ return FrameBuffer::Create({input_plane}, dimension,
+ FrameBuffer::Format::kGRAY, orientation,
+ timestamp);
+}
+
+// Creates a FrameBuffer from a raw YUV buffer and the given arguments.
+StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
+ const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
+ FrameBuffer::Format format, FrameBuffer::Dimension dimension,
+ int row_stride_y, int row_stride_uv, int pixel_stride_uv,
+ FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+ const int pixel_stride_y = 1;
+ std::vector<FrameBuffer::Plane> planes;
+ if (format == FrameBuffer::Format::kNV21 ||
+ format == FrameBuffer::Format::kYV12) {
+ planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}},
+ {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}},
+ {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}};
+ } else if (format == FrameBuffer::Format::kNV12 ||
+ format == FrameBuffer::Format::kYV21) {
+ planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}},
+ {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}},
+ {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}};
+ } else {
+ return absl::InvalidArgumentError(
+ absl::StrFormat("Input format is not YUV-like: %i.", format));
+ }
+ return FrameBuffer::Create(planes, dimension, format, orientation, timestamp);
+}
+
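+// Note on the YV12/YV21 cases below: `buffer` is assumed to be fully planar,
+// laid out as the Y plane (width * height bytes) followed by the two
+// quarter-size chroma planes: V then U for YV12, and U then V for YV21.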
+StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
+ const uint8* buffer, FrameBuffer::Dimension dimension,
+ const FrameBuffer::Format target_format,
+ FrameBuffer::Orientation orientation, absl::Time timestamp) {
+ switch (target_format) {
+ case FrameBuffer::Format::kNV12:
+ return CreateFromNV12RawBuffer(buffer, dimension, orientation, timestamp);
+ case FrameBuffer::Format::kNV21:
+ return CreateFromNV21RawBuffer(buffer, dimension, orientation, timestamp);
+ case FrameBuffer::Format::kYV12: {
+ ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension,
+ GetUvPlaneDimension(dimension, target_format));
+ return CreateFromYuvRawBuffer(
+ /*y_plane=*/buffer,
+ /*u_plane=*/buffer + dimension.Size() + uv_dimension.Size(),
+ /*v_plane=*/buffer + dimension.Size(), target_format, dimension,
+ /*row_stride_y=*/dimension.width, uv_dimension.width,
+ /*pixel_stride_uv=*/1, orientation, timestamp);
+ }
+ case FrameBuffer::Format::kYV21: {
+ ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension,
+ GetUvPlaneDimension(dimension, target_format));
+ return CreateFromYuvRawBuffer(
+ /*y_plane=*/buffer, /*u_plane=*/buffer + dimension.Size(),
+ /*v_plane=*/buffer + dimension.Size() + uv_dimension.Size(),
+ target_format, dimension, /*row_stride_y=*/dimension.width,
+ uv_dimension.width,
+ /*pixel_stride_uv=*/1, orientation, timestamp);
+ }
+ case FrameBuffer::Format::kRGBA:
+ return CreateFromRgbaRawBuffer(buffer, dimension, orientation, timestamp);
+ case FrameBuffer::Format::kRGB:
+ return CreateFromRgbRawBuffer(buffer, dimension, orientation, timestamp);
+ case FrameBuffer::Format::kGRAY:
+ return CreateFromGrayRawBuffer(buffer, dimension, orientation, timestamp);
+ default:
+ return absl::InternalError(
+ absl::StrFormat("Unsupported buffer format: %i.", target_format));
+ }
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
new file mode 100644
index 00000000..e250d154
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h
@@ -0,0 +1,143 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_COMMON_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_COMMON_UTILS_H_
+
+#include <memory>
+
+#include "absl/status/status.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+constexpr int kRgbaPixelBytes = 4, kRgbPixelBytes = 3, kGrayPixelBytes = 1;
+
+// Miscellaneous Methods
+// -----------------------------------------------------------------
+
+// Returns the frame buffer size in bytes based on the input format and
+// dimensions. GRAY, YV12/YV21 are planar formats, NV12/NV21 are semi-planar
+// formats with an interleaved UV plane, and RGB/RGBA are interleaved formats.
+int GetFrameBufferByteSize(FrameBuffer::Dimension dimension,
+ FrameBuffer::Format format);
+
+// Returns pixel stride info for kGRAY, kRGB, kRGBA formats.
+tflite::support::StatusOr<int> GetPixelStrides(FrameBuffer::Format format);
+
+// Returns the biplanar UV raw buffer for NV12/NV21 frame buffer.
+tflite::support::StatusOr<const uint8*> GetUvRawBuffer(
+ const FrameBuffer& buffer);
+
+// Returns U or V plane dimension with the given buffer `dimension` and
+// `format`. Only supports NV12/NV21/YV12/YV21 formats. Returns
+// InvalidArgumentError if 'dimension' is invalid or 'format' is not one of the
+// supported formats. This method assumes the U and V planes share the same
+// dimension, in particular for the YV12 / YV21 formats.
+tflite::support::StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
+ FrameBuffer::Dimension dimension, FrameBuffer::Format format);
+
+// Returns crop dimension based on crop start and end points.
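+// For example, GetCropDimension(/*x0=*/10, /*x1=*/109, /*y0=*/20, /*y1=*/219)
+// returns {/*width=*/100, /*height=*/200}, since both end points are
+// inclusive.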
+FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1);
+
+// Validation Methods
+// -----------------------------------------------------------------
+
+// Validates that the given buffer has the correct metadata. Returns an error
+// if any of its planes is missing stride information.
+absl::Status ValidateBufferPlaneMetadata(const FrameBuffer& buffer);
+
+// Validates that the given buffer has the correct format for its configuration.
+absl::Status ValidateBufferFormat(const FrameBuffer& buffer);
+
+// Validates that the given buffers have the correct format for their
+// configuration.
+absl::Status ValidateBufferFormats(const FrameBuffer& buffer1,
+ const FrameBuffer& buffer2);
+
+// Validates the given inputs for resizing `buffer`.
+absl::Status ValidateResizeBufferInputs(const FrameBuffer& buffer,
+ const FrameBuffer& output_buffer);
+
+// Validates the given inputs for rotating `buffer`.
+absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
+ const FrameBuffer& output_buffer,
+ int angle_deg);
+
+// Validates the given inputs for cropping `buffer`.
+//
+// (x0, y0) represents the top-left point of the crop region.
+// (x1, y1) represents the bottom-right point of the crop region; both points
+// are inclusive.
+absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
+ const FrameBuffer& output_buffer, int x0,
+ int y0, int x1, int y1);
+
+// Validates the given inputs for flipping `buffer` horizontally or vertically.
+absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer,
+ const FrameBuffer& output_buffer);
+
+// Validates that `from_format` can be converted to `to_format`.
+//
+// The given formats must not be equal.
+absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
+ FrameBuffer::Format to_format);
+
+// Creation Methods
+// -----------------------------------------------------------------
+
+// Creates a FrameBuffer from a raw RGBA buffer and the given arguments.
+std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
+ const uint8* input, FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
+ absl::Time timestamp = absl::Now());
+
+// Creates a FrameBuffer from a raw RGB buffer and the given arguments.
+std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
+ const uint8* input, FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
+ absl::Time timestamp = absl::Now());
+
+// Creates a FrameBuffer from a raw grayscale buffer and the given arguments.
+std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
+ const uint8* input, FrameBuffer::Dimension dimension,
+ FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
+ absl::Time timestamp = absl::Now());
+
+// Creates a FrameBuffer from a raw YUV buffer and the given arguments.
+tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
+ const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
+ FrameBuffer::Format format, FrameBuffer::Dimension dimension,
+ int row_stride_y, int row_stride_uv, int pixel_stride_uv,
+ FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
+ absl::Time timestamp = absl::Now());
+
+// Creates a FrameBuffer from a raw buffer and the given arguments.
+tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
+ const uint8* buffer, FrameBuffer::Dimension dimension,
+ FrameBuffer::Format target_format,
+ FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft,
+ absl::Time timestamp = absl::Now());
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_COMMON_UTILS_H_
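For reference, a minimal usage sketch of the creation helpers above; buffer
contents and dimensions are placeholders, and it assumes `Dimension`
brace-initializes as {width, height} and that `uint8` is `unsigned char`:

    #include <vector>

    #include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"

    int main() {
      using tflite::task::vision::CreateFromRgbRawBuffer;
      using tflite::task::vision::kRgbPixelBytes;
      // Hypothetical 2x2 RGB image, 3 bytes per pixel, all black.
      const std::vector<unsigned char> pixels(2 * 2 * kRgbPixelBytes, 0);
      auto frame_buffer =
          CreateFromRgbRawBuffer(pixels.data(), {/*width=*/2, /*height=*/2});
      // FrameBuffer does not own the pixels: `pixels` must outlive it.
      return frame_buffer == nullptr ? 1 : 0;
    }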
diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
new file mode 100644
index 00000000..9b9d830e
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc
@@ -0,0 +1,619 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
+
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+
+// Exif grouping to help determine the rotation and flipping needed between
+// different orientations.
+constexpr int kExifGroup[] = {1, 6, 3, 8, 2, 5, 4, 7};
+// Size of each Exif orientation group in `kExifGroup` (two groups of four).
+constexpr int kExifGroupSize = 4;
+
+// Returns the orientation's position in the Exif group.
+static int GetOrientationIndex(FrameBuffer::Orientation orientation) {
+ const int* index = std::find(kExifGroup, kExifGroup + kExifGroupSize * 2,
+ static_cast<int>(orientation));
+ if (index < kExifGroup + kExifGroupSize * 2) {
+ return std::distance(kExifGroup, index);
+ }
+ return -1;
+}
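+
+// For example, assuming FrameBuffer::Orientation values match the Exif
+// orientation codes stored in `kExifGroup` (kTopLeft = 1, kRightTop = 6, ...),
+// GetOrientationIndex(kTopLeft) returns 0 and GetOrientationIndex(kRightTop)
+// returns 1.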
+
+// Returns the coordinates of `box` with respect to its containing image (of
+// dimension `frame_dimension`) after the orientation change. The `angle` is
+// specified in counterclockwise degrees and must be one of [0, 90, 180, 270].
+//
+// The diagrams below illustrate calling this method with a 90 degree CCW
+// rotation.
+//
+// [1]-[4] denote the image corners and 1-4 denote the box corners. The *
+// denotes the current origin.
+//
+// width
+// [1]*----------------[2]
+// | |
+// | |
+// | 1*-----2 | height
+// | | box | |
+// | 3------4 |
+// [3]-----------------[4]
+//
+// When rotating the above image by 90 degrees CCW, the origin also changes
+// with respect to its containing coordinate space.
+//
+// height
+// [2]*----------[4]
+// | |
+// | 2*---4 |
+// | |box | |
+// | | | | width
+// | 1----3 |
+// | |
+// | |
+// | |
+// [1]-----------[3]
+//
+// The origin is always defined by the top-left corner. After rotation, the
+// box origin changes from 1 to 2.
+// The new box origin is (x: box.origin_y, y: width - (box.origin_x +
+// box.width)). The new box dimension is (w: box.height, h: box.width).
+//
+static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle,
+ FrameBuffer::Dimension frame_dimension) {
+ int rx = box.origin_x(), ry = box.origin_y(), rw = box.width(),
+ rh = box.height();
+ const int box_right_bound =
+ frame_dimension.width - (box.origin_x() + box.width());
+ const int box_bottom_bound =
+ frame_dimension.height - (box.origin_y() + box.height());
+ switch (angle) {
+ case 90:
+ rx = box.origin_y();
+ ry = box_right_bound;
+ using std::swap;
+ swap(rw, rh);
+ break;
+ case 180:
+ rx = box_right_bound;
+ ry = box_bottom_bound;
+ break;
+ case 270:
+ rx = box_bottom_bound;
+ ry = box.origin_x();
+ using std::swap;
+ swap(rw, rh);
+ break;
+ }
+ BoundingBox result;
+ result.set_origin_x(rx);
+ result.set_origin_y(ry);
+ result.set_width(rw);
+ result.set_height(rh);
+ return result;
+}
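+
+// A worked example of the mapping above (illustrative values): rotating a box
+// at (x:10, y:20) with size (w:30, h:10) inside a 100x50 image by 90 degrees
+// CCW yields the origin (x:20, y:100 - (10 + 30) = 60) and size (w:10, h:30).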
+
+// Returns the input coordinates with respect to their containing image (of
+// dimension `frame_dimension`) after the orientation change. The `angle` is
+// specified in counterclockwise degrees and must be one of [0, 90, 180, 270].
+//
+// See `RotateBoundingBox` above for more details.
+static void RotateCoordinates(int from_x, int from_y, int angle,
+ const FrameBuffer::Dimension& frame_dimension,
+ int* to_x, int* to_y) {
+ switch (angle) {
+ case 0:
+ *to_x = from_x;
+ *to_y = from_y;
+ break;
+ case 90:
+ *to_x = from_y;
+ *to_y = frame_dimension.width - from_x - 1;
+ break;
+ case 180:
+ *to_x = frame_dimension.width - from_x - 1;
+ *to_y = frame_dimension.height - from_y - 1;
+ break;
+ case 270:
+ *to_x = frame_dimension.height - from_y - 1;
+ *to_y = from_x;
+ break;
+ }
+}
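+
+// For example (illustrative values), rotating the point (x:10, y:20) inside a
+// 100x50 image by 90 degrees CCW yields (x:20, y:100 - 10 - 1 = 89).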
+
+} // namespace
+
+int GetBufferByteSize(FrameBuffer::Dimension dimension,
+ FrameBuffer::Format format) {
+ return GetFrameBufferByteSize(dimension, format);
+}
+
+FrameBufferUtils::FrameBufferUtils(ProcessEngine engine) {
+ switch (engine) {
+ case ProcessEngine::kLibyuv:
+ utils_ = absl::make_unique<LibyuvFrameBufferUtils>();
+ break;
+ default:
+ TF_LITE_FATAL(
+ absl::StrFormat("Unexpected ProcessEngine: %d.", engine).c_str());
+ }
+}
+
+BoundingBox OrientBoundingBox(const BoundingBox& from_box,
+ FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation,
+ FrameBuffer::Dimension from_dimension) {
+ BoundingBox to_box = from_box;
+ OrientParams params = GetOrientParams(from_orientation, to_orientation);
+ // First, rotate if needed.
+ if (params.rotation_angle_deg > 0) {
+ to_box =
+ RotateBoundingBox(to_box, params.rotation_angle_deg, from_dimension);
+ }
+ // Then perform horizontal or vertical flip if needed.
+ FrameBuffer::Dimension to_dimension = from_dimension;
+ if (params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270) {
+ to_dimension.Swap();
+ }
+ if (params.flip == OrientParams::FlipType::kVertical) {
+ to_box.set_origin_y(to_dimension.height -
+ (to_box.origin_y() + to_box.height()));
+ }
+ if (params.flip == OrientParams::FlipType::kHorizontal) {
+ to_box.set_origin_x(to_dimension.width -
+ (to_box.origin_x() + to_box.width()));
+ }
+ return to_box;
+}
+
+BoundingBox OrientAndDenormalizeBoundingBox(
+ float from_left, float from_top, float from_right, float from_bottom,
+ FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation,
+ FrameBuffer::Dimension from_dimension) {
+ BoundingBox from_box;
+ from_box.set_origin_x(from_left * from_dimension.width);
+ from_box.set_origin_y(from_top * from_dimension.height);
+ from_box.set_width(round(abs(from_right - from_left) * from_dimension.width));
+ from_box.set_height(
+ round(abs(from_bottom - from_top) * from_dimension.height));
+ BoundingBox to_box = OrientBoundingBox(from_box, from_orientation,
+ to_orientation, from_dimension);
+ return to_box;
+}
+
+void OrientCoordinates(int from_x, int from_y,
+ FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation,
+ FrameBuffer::Dimension from_dimension, int* to_x,
+ int* to_y) {
+ *to_x = from_x;
+ *to_y = from_y;
+ OrientParams params = GetOrientParams(from_orientation, to_orientation);
+ // First, rotate if needed.
+ if (params.rotation_angle_deg > 0) {
+ RotateCoordinates(from_x, from_y, params.rotation_angle_deg, from_dimension,
+ to_x, to_y);
+ }
+ // Then perform horizontal or vertical flip if needed.
+ FrameBuffer::Dimension to_dimension = from_dimension;
+ if (params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270) {
+ to_dimension.Swap();
+ }
+ if (params.flip == OrientParams::FlipType::kVertical) {
+ *to_y = to_dimension.height - *to_y - 1;
+ }
+ if (params.flip == OrientParams::FlipType::kHorizontal) {
+ *to_x = to_dimension.width - *to_x - 1;
+ }
+}
+
+// The algorithm is based on grouping orientations into two groups with a
+// specific order. The two orientation groups are {1, 6, 3, 8} and {2, 5, 4,
+// 7}. See the image
+// (https://www.impulseadventure.com/photo/images/orient_flag.gif) for a
+// visual illustration of the grouping.
+//
+// Each group contains elements that can be transformed into one another by
+// rotation. The element order within a group is important: the distance
+// between two elements indicates the multiple of 90 degrees needed to orient
+// from one element to another. For example, to orient element 1 to element 6,
+// a 90 degree CCW rotation is needed.
+//
+// The corresponding order between the two groups is important such that an
+// even index indicates the need for horizontal flipping and an odd index
+// indicates the need for vertical flipping. For example, to orient element 1
+// to element 2 (even index), a horizontal flip is needed.
+//
+// The implementation determines the group and element index of from and to
+// orientations. Based on the group and element index information, the above
+// characteristic is used to calculate the rotation angle and the need for
+// horizontal or vertical flipping.
+OrientParams GetOrientParams(FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation) {
+ int from_index = GetOrientationIndex(from_orientation);
+ int to_index = GetOrientationIndex(to_orientation);
+ int angle = 0;
+ absl::optional<OrientParams::FlipType> flip;
+
+ TFLITE_DCHECK(from_index > -1 && to_index > -1);
+
+ if ((from_index < kExifGroupSize && to_index < kExifGroupSize) ||
+ (from_index >= kExifGroupSize && to_index >= kExifGroupSize)) {
+ // Only needs rotation.
+
+ // The difference between the orientations' positions translates to the
+ // multiple of 90 degrees needed for the conversion. The position difference
+ // calculation within a group is circular.
+ angle = (kExifGroupSize - (from_index - to_index)) % kExifGroupSize * 90;
+ } else {
+ // Needs rotation and flipping.
+ int from_index_mod = from_index % kExifGroupSize;
+ int to_index_mod = to_index % kExifGroupSize;
+ angle = (kExifGroupSize - (from_index_mod - to_index_mod)) %
+ kExifGroupSize * 90;
+ if (to_index_mod % 2 == 1) {
+ flip = OrientParams::FlipType::kVertical;
+ } else {
+ flip = OrientParams::FlipType::kHorizontal;
+ }
+ }
+ return {angle, flip};
+}
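+
+// For example, assuming the Exif-valued orientations described above,
+// GetOrientParams(kTopLeft /*1*/, kRightTop /*6*/) stays within the first
+// group and yields {/*rotation_angle_deg=*/90, /*flip=*/absl::nullopt}, while
+// GetOrientParams(kTopLeft /*1*/, kTopRight /*2*/) crosses groups and yields
+// {/*rotation_angle_deg=*/0, /*flip=*/OrientParams::FlipType::kHorizontal}.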
+
+bool RequireDimensionSwap(FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation) {
+ OrientParams params = GetOrientParams(from_orientation, to_orientation);
+ return params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270;
+}
+
+absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer, int x0, int y0,
+ int x1, int y1,
+ FrameBuffer* output_buffer) {
+ TFLITE_DCHECK(utils_ != nullptr);
+ return utils_->Crop(buffer, x0, y0, x1, y1, output_buffer);
+}
+
+FrameBuffer::Dimension FrameBufferUtils::GetSize(
+ const FrameBuffer& buffer, const FrameBufferOperation& operation) {
+ FrameBuffer::Dimension dimension = buffer.dimension();
+ if (absl::holds_alternative<OrientOperation>(operation)) {
+ OrientParams params =
+ GetOrientParams(buffer.orientation(),
+ absl::get<OrientOperation>(operation).to_orientation);
+ if (params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270) {
+ dimension.Swap();
+ }
+ } else if (absl::holds_alternative<CropResizeOperation>(operation)) {
+ const auto& crop_resize = absl::get<CropResizeOperation>(operation);
+ dimension = crop_resize.resize_dimension;
+ }
+ return dimension;
+}
+
+std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes(
+ const uint8* buffer, FrameBuffer::Dimension dimension,
+ FrameBuffer::Format format) {
+ std::vector<FrameBuffer::Plane> planes;
+ switch (format) {
+ case FrameBuffer::Format::kGRAY:
+ planes.push_back({/*buffer=*/buffer,
+ /*stride=*/{/*row_stride_bytes=*/dimension.width * 1,
+ /*pixel_stride_bytes=*/1}});
+ break;
+ case FrameBuffer::Format::kRGB:
+ planes.push_back({/*buffer=*/buffer,
+ /*stride=*/{/*row_stride_bytes=*/dimension.width * 3,
+ /*pixel_stride_bytes=*/3}});
+ break;
+ case FrameBuffer::Format::kRGBA:
+ planes.push_back({/*buffer=*/buffer,
+ /*stride=*/{/*row_stride_bytes=*/dimension.width * 4,
+ /*pixel_stride_bytes=*/4}});
+ break;
+ case FrameBuffer::Format::kNV21:
+ case FrameBuffer::Format::kNV12: {
+ planes.push_back(
+ {buffer, /*stride=*/{/*row_stride_bytes=*/dimension.width,
+ /*pixel_stride_bytes=*/1}});
+ planes.push_back({buffer + (dimension.width * dimension.height),
+ /*stride=*/{/*row_stride_bytes=*/dimension.width,
+ /*pixel_stride_bytes=*/2}});
+ } break;
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21: {
+ const int y_buffer_size = dimension.width * dimension.height;
+ const int uv_row_stride = (dimension.width + 1) / 2;
+ const int uv_buffer_size = uv_row_stride * (dimension.height + 1) / 2;
+ planes.push_back(
+ {buffer, /*stride=*/{/*row_stride_bytes=*/dimension.width,
+ /*pixel_stride_bytes=*/1}});
+ planes.push_back(
+ {buffer + y_buffer_size, /*stride=*/{
+ /*row_stride_bytes=*/uv_row_stride, /*pixel_stride_bytes=*/1}});
+ planes.push_back(
+ {buffer + y_buffer_size + uv_buffer_size, /*stride=*/{
+ /*row_stride_bytes=*/uv_row_stride, /*pixel_stride_bytes=*/1}});
+ } break;
+ default:
+ break;
+ }
+ return planes;
+}
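+
+// For example (illustrative values), calling GetPlanes on a 4x4 kNV12 buffer
+// yields two planes: a 16-byte Y plane at offset 0 with row stride 4 and
+// pixel stride 1, and an interleaved UV plane starting at offset 16 with row
+// stride 4 and pixel stride 2.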
+
+FrameBuffer::Orientation FrameBufferUtils::GetOrientation(
+ const FrameBuffer& buffer, const FrameBufferOperation& operation) {
+ if (absl::holds_alternative<OrientOperation>(operation)) {
+ return absl::get<OrientOperation>(operation).to_orientation;
+ }
+ return buffer.orientation();
+}
+
+FrameBuffer::Format FrameBufferUtils::GetFormat(
+ const FrameBuffer& buffer, const FrameBufferOperation& operation) {
+ if (absl::holds_alternative<ConvertOperation>(operation)) {
+ return absl::get<ConvertOperation>(operation).to_format;
+ }
+ return buffer.format();
+}
+
+absl::Status FrameBufferUtils::Execute(const FrameBuffer& buffer,
+ const FrameBufferOperation& operation,
+ FrameBuffer* output_buffer) {
+ if (absl::holds_alternative<CropResizeOperation>(operation)) {
+ const auto& params = absl::get<CropResizeOperation>(operation);
+ RETURN_IF_ERROR(
+ Crop(buffer, params.crop_origin_x, params.crop_origin_y,
+ (params.crop_dimension.width + params.crop_origin_x - 1),
+ (params.crop_dimension.height + params.crop_origin_y - 1),
+ output_buffer));
+ } else if (absl::holds_alternative<ConvertOperation>(operation)) {
+ RETURN_IF_ERROR(Convert(buffer, output_buffer));
+ } else if (absl::holds_alternative<OrientOperation>(operation)) {
+ RETURN_IF_ERROR(Orient(buffer, output_buffer));
+ } else {
+ return absl::UnimplementedError(absl::StrFormat(
+ "FrameBufferOperation %i is not supported.", operation.index()));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status FrameBufferUtils::Resize(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ TFLITE_DCHECK(utils_ != nullptr);
+ return utils_->Resize(buffer, output_buffer);
+}
+
+absl::Status FrameBufferUtils::Rotate(const FrameBuffer& buffer,
+ RotationDegree rotation,
+ FrameBuffer* output_buffer) {
+ TFLITE_DCHECK(utils_ != nullptr);
+ return utils_->Rotate(buffer, 90 * static_cast<int>(rotation), output_buffer);
+}
+
+absl::Status FrameBufferUtils::FlipHorizontally(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ TFLITE_DCHECK(utils_ != nullptr);
+ return utils_->FlipHorizontally(buffer, output_buffer);
+}
+
+absl::Status FrameBufferUtils::FlipVertically(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ TFLITE_DCHECK(utils_ != nullptr);
+ return utils_->FlipVertically(buffer, output_buffer);
+}
+
+absl::Status FrameBufferUtils::Convert(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ TFLITE_DCHECK(utils_ != nullptr);
+ return utils_->Convert(buffer, output_buffer);
+}
+
+absl::Status FrameBufferUtils::Orient(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ TFLITE_DCHECK(utils_ != nullptr);
+
+ OrientParams params =
+ GetOrientParams(buffer.orientation(), output_buffer->orientation());
+ if (params.rotation_angle_deg == 0 && !params.flip.has_value()) {
+ // If no rotation or flip is needed, we will copy the buffer to
+ // output_buffer.
+ return utils_->Resize(buffer, output_buffer);
+ }
+
+ if (params.rotation_angle_deg == 0) {
+ // Only perform flip operation.
+ switch (*params.flip) {
+ case OrientParams::FlipType::kHorizontal:
+ return utils_->FlipHorizontally(buffer, output_buffer);
+ case OrientParams::FlipType::kVertical:
+ return utils_->FlipVertically(buffer, output_buffer);
+ }
+ }
+
+ if (!params.flip.has_value()) {
+ // Only perform rotation operation.
+ return utils_->Rotate(buffer, params.rotation_angle_deg, output_buffer);
+ }
+
+ // Perform rotation and flip operations.
+ // Create a temporary buffer to hold the rotation result.
+ auto tmp_buffer = absl::make_unique<uint8[]>(
+ GetBufferByteSize(output_buffer->dimension(), output_buffer->format()));
+ auto tmp_frame_buffer = FrameBuffer::Create(
+ GetPlanes(tmp_buffer.get(), output_buffer->dimension(),
+ output_buffer->format()),
+ output_buffer->dimension(), buffer.format(), buffer.orientation());
+
+ RETURN_IF_ERROR(utils_->Rotate(buffer, params.rotation_angle_deg,
+ tmp_frame_buffer.get()));
+ if (params.flip == OrientParams::FlipType::kHorizontal) {
+ return utils_->FlipHorizontally(*tmp_frame_buffer, output_buffer);
+ } else {
+ return utils_->FlipVertically(*tmp_frame_buffer, output_buffer);
+ }
+}
+
+absl::Status FrameBufferUtils::Execute(
+ const FrameBuffer& buffer,
+ const std::vector<FrameBufferOperation>& operations,
+ FrameBuffer* output_buffer) {
+ // Reference variables for swapping the input and output buffers of each
+ // command.
+ FrameBuffer input_frame_buffer = buffer;
+ FrameBuffer temp_frame_buffer = buffer;
+
+ // Temporary buffers and their sizes, used to hold intermediate results.
+ int buffer1_size = 0;
+ int buffer2_size = 0;
+ std::unique_ptr<uint8[]> buffer1;
+ std::unique_ptr<uint8[]> buffer2;
+
+ for (int i = 0; i < operations.size(); i++) {
+ const FrameBufferOperation& operation = operations[i];
+
+ // The first command's input is always `buffer`. Before processing each
+ // subsequent command, input_frame_buffer is pointed at the previous
+ // command's output buffer.
+ if (i == 0) {
+ input_frame_buffer = buffer;
+ } else {
+ input_frame_buffer = temp_frame_buffer;
+ }
+
+ // Calculates the resulting metadata from the command and the input.
+ FrameBuffer::Dimension new_size = GetSize(input_frame_buffer, operation);
+ FrameBuffer::Orientation new_orientation =
+ GetOrientation(input_frame_buffer, operation);
+ FrameBuffer::Format new_format = GetFormat(input_frame_buffer, operation);
+ int byte_size = GetBufferByteSize(new_size, new_format);
+
+ // The last command's output buffer is always passed in `output_buffer`.
+ // For other commands, we create temporary FrameBuffer for processing.
+ if ((i + 1) == operations.size()) {
+ temp_frame_buffer = *output_buffer;
+ // Validate that the `output_buffer` metadata matches the resulting
+ // metadata of the command chain.
+ if (temp_frame_buffer.format() != new_format ||
+ temp_frame_buffer.orientation() != new_orientation ||
+ temp_frame_buffer.dimension() != new_size) {
+ return absl::InvalidArgumentError(
+ "The output metadata does not match pipeline result metadata.");
+ }
+ } else {
+ // Create a temporary buffer to hold intermediate results. For simplicity,
+ // we only create one contiguous memory block with no padding for
+ // intermediate results.
+ //
+ // We hold maximum 2 temporary buffers in memory at any given time.
+ //
+ // The pipeline is a linear chain. The output buffer from the previous
+ // command becomes the input buffer for the next command. We simply use the
+ // odd / even index to swap between buffers.
+ std::vector<FrameBuffer::Plane> planes;
+ if (i % 2 == 0) {
+ if (buffer1_size < byte_size) {
+ buffer1_size = byte_size;
+ buffer1 = absl::make_unique<uint8[]>(byte_size);
+ }
+ planes = GetPlanes(buffer1.get(), new_size, new_format);
+ } else {
+ if (buffer2_size < byte_size) {
+ buffer2_size = byte_size;
+ buffer2 = absl::make_unique<uint8[]>(byte_size);
+ }
+ planes = GetPlanes(buffer2.get(), new_size, new_format);
+ }
+ if (planes.empty()) {
+ return absl::InternalError("Failed to construct temporary buffer.");
+ }
+ temp_frame_buffer = FrameBuffer(planes, new_size, new_format,
+ new_orientation, buffer.timestamp());
+ }
+ RETURN_IF_ERROR(Execute(input_frame_buffer, operation, &temp_frame_buffer));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status FrameBufferUtils::Preprocess(
+ const FrameBuffer& buffer, absl::optional<BoundingBox> bounding_box,
+ FrameBuffer* output_buffer) {
+ std::vector<FrameBufferOperation> frame_buffer_operations;
+ // Handle cropping and resizing.
+ bool needs_dimension_swap =
+ RequireDimensionSwap(buffer.orientation(), output_buffer->orientation());
+ // For intermediate steps, we need to use dimensions based on the input
+ // orientation.
+ FrameBuffer::Dimension pre_orient_dimension = output_buffer->dimension();
+ if (needs_dimension_swap) {
+ pre_orient_dimension.Swap();
+ }
+
+ if (bounding_box.has_value()) {
+ // Cropping case.
+ frame_buffer_operations.push_back(CropResizeOperation(
+ bounding_box.value().origin_x(), bounding_box.value().origin_y(),
+ FrameBuffer::Dimension{bounding_box.value().width(),
+ bounding_box.value().height()},
+ pre_orient_dimension));
+ } else if (pre_orient_dimension != buffer.dimension()) {
+ // Resizing case.
+ frame_buffer_operations.push_back(
+ CropResizeOperation(0, 0, buffer.dimension(), pre_orient_dimension));
+ }
+
+ // Handle color space conversion.
+ if (output_buffer->format() != buffer.format()) {
+ frame_buffer_operations.push_back(
+ ConvertOperation(output_buffer->format()));
+ }
+
+ // Handle orientation conversion.
+ if (output_buffer->orientation() != buffer.orientation()) {
+ frame_buffer_operations.push_back(
+ OrientOperation(output_buffer->orientation()));
+ }
+
+ // Execute the processing pipeline.
+ if (frame_buffer_operations.empty()) {
+ // Using resize to perform copy.
+ RETURN_IF_ERROR(Resize(buffer, output_buffer));
+ } else {
+ RETURN_IF_ERROR(Execute(buffer, frame_buffer_operations, output_buffer));
+ }
+ return absl::OkStatus();
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
new file mode 100644
index 00000000..90a7491e
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h
@@ -0,0 +1,292 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/types/optional.h"
+#include "absl/types/variant.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Returns the minimal backing buffer size in bytes needed to hold a frame of
+// the given format and dimensions.
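+//
+// For example, assuming the conventional packed layouts, a 224x224 kRGB frame
+// requires 224 * 224 * 3 = 150528 bytes, while a 224x224 kNV21 frame requires
+// 224 * 224 * 3 / 2 = 75264 bytes.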
+int GetBufferByteSize(FrameBuffer::Dimension dimension,
+ FrameBuffer::Format format);
+
+// Rotates the `from_box` in `from_orientation` to `to_orientation` within an
+// image of size `from_dimension`.
+BoundingBox OrientBoundingBox(const BoundingBox& from_box,
+ FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation,
+ FrameBuffer::Dimension from_dimension);
+
+// Same as OrientBoundingBox but from normalized coordinates.
+BoundingBox OrientAndDenormalizeBoundingBox(
+ float from_left, float from_top, float from_right, float from_bottom,
+ FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation,
+ FrameBuffer::Dimension from_dimension);
+
+// Rotates `(from_x, from_y)` coordinates from an image of dimension
+// `from_dimension` and orientation `from_orientation` into `(to_x, to_y)`
+// coordinates with orientation `to_orientation`.
+void OrientCoordinates(int from_x, int from_y,
+ FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation,
+ FrameBuffer::Dimension from_dimension, int* to_x,
+ int* to_y);
+
+// Returns whether the conversion from `from_orientation` to `to_orientation`
+// requires a 90 or 270 degree rotation.
+bool RequireDimensionSwap(FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation);
+
+// Structure to express parameters needed to achieve orientation conversion.
+struct OrientParams {
+ // Counterclockwise rotation angle in degrees. This is expressed as a
+ // multiple of 90 degrees.
+ int rotation_angle_deg;
+ // Flipping operation. It must come after the rotation.
+ enum class FlipType { kHorizontal, kVertical };
+ absl::optional<FlipType> flip;
+};
+
+// Returns rotation angle and the need for horizontal flipping or vertical
+// flipping.
+OrientParams GetOrientParams(FrameBuffer::Orientation from_orientation,
+ FrameBuffer::Orientation to_orientation);
+
+// The parameters needed to crop / resize.
+//
+// The coordinate system has its origin at the upper left corner, and
+// positive values extend down and to the right from it.
+//
+// After the operation, (`crop_origin_x`, `crop_origin_y`) becomes the new
+// origin. `crop_dimension` defines the desired cropping region. After
+// cropping, a resize is performed based on `resize_dimension`.
+//
+// To perform cropping only, `crop_dimension` should be the same as
+// `resize_dimension`.
+struct CropResizeOperation {
+ CropResizeOperation(int crop_origin_x, int crop_origin_y,
+ FrameBuffer::Dimension crop_dimension,
+ FrameBuffer::Dimension resize_dimension)
+ : crop_origin_x(crop_origin_x),
+ crop_origin_y(crop_origin_y),
+ crop_dimension(crop_dimension),
+ resize_dimension(resize_dimension) {}
+
+ int crop_origin_x;
+ int crop_origin_y;
+ FrameBuffer::Dimension crop_dimension;
+ FrameBuffer::Dimension resize_dimension;
+};
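+
+// For example, to crop a 100x100 region at the origin and keep its original
+// size (no resizing), one could write (an editorial sketch):
+//
+//   CropResizeOperation op(/*crop_origin_x=*/0, /*crop_origin_y=*/0,
+//                          /*crop_dimension=*/{100, 100},
+//                          /*resize_dimension=*/{100, 100});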
+
+// The parameters needed to convert to the specified format.
+struct ConvertOperation {
+ explicit ConvertOperation(FrameBuffer::Format to_format)
+ : to_format(to_format) {}
+ FrameBuffer::Format to_format;
+};
+
+// The parameters needed to change the orientation.
+struct OrientOperation {
+ explicit OrientOperation(FrameBuffer::Orientation to_orientation)
+ : to_orientation(to_orientation) {}
+ FrameBuffer::Orientation to_orientation;
+};
+
+// A variant of the supported operations on FrameBuffers. Alias for user
+// convenience.
+using FrameBufferOperation =
+ absl::variant<CropResizeOperation, ConvertOperation, OrientOperation>;
+
+// Image processing utility. This utility provides both basic image buffer
+// manipulations (e.g. rotation, format conversion, resizing, etc) as well as
+// capability for chaining pipeline executions. The actual buffer processing
+// engine is configurable to allow optimization based on platforms.
+//
+// Examples:
+//
+// // Create an instance of FrameBufferUtils with the libyuv processing engine.
+// std::unique_ptr<FrameBufferUtils> utils =
+//     FrameBufferUtils::Create(FrameBufferUtils::ProcessEngine::kLibyuv);
+//
+// // Perform a single basic operation with each individual call.
+// std::unique_ptr<FrameBuffer> input = FrameBuffer::Create(...);
+// std::unique_ptr<FrameBuffer> output = FrameBuffer::Create(...);
+// utils->Orient(*input, output.get());
+// utils->Resize(*input, output.get());
+//
+// // Chaining processing operations.
+// const std::vector<FrameBufferOperation> operations = {
+//     ConvertOperation(FrameBuffer::Format::kNV21),
+//     CropResizeOperation(/*crop_origin_x=*/20, /*crop_origin_y=*/20,
+//                         /*crop_dimension=*/{10, 10},
+//                         /*resize_dimension=*/{10, 10}),
+//     OrientOperation(FrameBuffer::Orientation::kLeftTop)};
+// utils->Execute(*input, operations, output.get());
+class FrameBufferUtils {
+ public:
+ // Counter-clockwise rotation in degrees.
+ enum class RotationDegree { k0 = 0, k90 = 1, k180 = 2, k270 = 3 };
+
+ // Underlying process engine used for performing operations.
+ enum class ProcessEngine {
+ kLibyuv,
+ };
+
+ // Factory method for creating a FrameBufferUtils instance. The processing
+ // engine is defined by `engine`.
+ static std::unique_ptr<FrameBufferUtils> Create(ProcessEngine engine) {
+ return absl::make_unique<FrameBufferUtils>(engine);
+ }
+
+ explicit FrameBufferUtils(ProcessEngine engine);
+
+ // Performs cropping operation.
+ //
+ // The coordinate system has its origin at the upper left corner, and
+ // positive values extend down and to the right from it. After cropping,
+ // (x0, y0) becomes (0, 0). The new width and height are
+ // (x1 - x0 + 1, y1 - y0 + 1).
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result. If the `output_buffer`
+ // dimension does not match the crop dimension, then a resize is
+ // automatically performed.
+ absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
+ FrameBuffer* output_buffer);
+
+ // Performs resizing operation.
+ //
+ // The resize dimension is determined based on output_buffer's size metadata.
+ //
+ // The output_buffer should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ absl::Status Resize(const FrameBuffer& buffer, FrameBuffer* output_buffer);
+
+ // Performs rotation operation.
+ //
+ // The rotation is specified in counter-clockwise direction.
+ //
+ // The output_buffer should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ absl::Status Rotate(const FrameBuffer& buffer, RotationDegree rotation,
+ FrameBuffer* output_buffer);
+
+ // Performs horizontal flip operation.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ absl::Status FlipHorizontally(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer);
+
+ // Performs vertical flip operation.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ absl::Status FlipVertically(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer);
+
+ // Performs buffer format conversion.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ absl::Status Convert(const FrameBuffer& buffer, FrameBuffer* output_buffer);
+
+ // Performs buffer orientation conversion. Depending on the orientations,
+ // this method may perform rotation and optional flipping operations.
+ //
+ // If `buffer` and `output_buffer` have the same orientation, then a copy
+ // operation is performed.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ absl::Status Orient(const FrameBuffer& buffer, FrameBuffer* output_buffer);
+
+ // Performs the image processing operations specified, in that order.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ absl::Status Execute(const FrameBuffer& buffer,
+ const std::vector<FrameBufferOperation>& operations,
+ FrameBuffer* output_buffer);
+
+ // Performs a chain of operations to convert `buffer` to desired metadata
+ // (width, height, format, orientation) defined by `output_buffer` and
+ // optional cropping (`bounding_box`).
+ //
+ // Internally, a chain of operations is constructed. For performance
+ // optimization, operations are performed in the following order: crop,
+ // resize, convert color space format, and rotate.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result. Insufficient backing
+ // buffer size may cause a garbage result or a crash. Use `GetBufferByteSize`
+ // to calculate the minimal buffer size.
+ //
+ // If the `buffer` is already in the desired format, then an extra copy is
+ // performed.
+ //
+ // The input param `bounding_box` is defined in the `buffer` coordinate space.
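+ //
+ // Example (an editorial sketch; the dimensions and backing buffer handling
+ // are illustrative):
+ //
+ //   auto data = absl::make_unique<uint8[]>(
+ //       GetBufferByteSize({224, 224}, FrameBuffer::Format::kRGB));
+ //   std::unique_ptr<FrameBuffer> output =
+ //       CreateFromRgbRawBuffer(data.get(), {224, 224});
+ //   absl::Status status =
+ //       utils->Preprocess(*input, absl::nullopt, output.get());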
+ absl::Status Preprocess(const FrameBuffer& buffer,
+ absl::optional<BoundingBox> bounding_box,
+ FrameBuffer* output_buffer);
+
+ private:
+ // Returns the new FrameBuffer size after the operation is applied.
+ FrameBuffer::Dimension GetSize(const FrameBuffer& buffer,
+ const FrameBufferOperation& operation);
+
+ // Returns the new FrameBuffer orientation after command is processed.
+ FrameBuffer::Orientation GetOrientation(
+ const FrameBuffer& buffer, const FrameBufferOperation& operation);
+
+ // Returns the new FrameBuffer format after command is processed.
+ FrameBuffer::Format GetFormat(const FrameBuffer& buffer,
+ const FrameBufferOperation& operation);
+
+ // Returns Plane structs based on a one-dimensional buffer and its metadata.
+ // If an error occurs, it returns an empty vector.
+ std::vector<FrameBuffer::Plane> GetPlanes(const uint8* buffer,
+ FrameBuffer::Dimension dimension,
+ FrameBuffer::Format format);
+
+ // Executes command with params.
+ absl::Status Execute(const FrameBuffer& buffer,
+ const FrameBufferOperation& operation,
+ FrameBuffer* output_buffer);
+
+ // Execution engine that conforms to FrameBufferUtilsInterface.
+ std::unique_ptr<FrameBufferUtilsInterface> utils_;
+};
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_H_
diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
new file mode 100644
index 00000000..502e998d
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h
@@ -0,0 +1,88 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_INTERFACE_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_INTERFACE_H_
+
+#include "absl/status/status.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Interface for the FrameBuffer image processing library.
+class FrameBufferUtilsInterface {
+ public:
+ virtual ~FrameBufferUtilsInterface() = default;
+
+ // Crops `buffer` to the specified points.
+ //
+ // The coordinate system has its origin at the upper left corner, and
+ // positive values extend down and to the right from it. After cropping,
+ // the top left point becomes (0, 0). The new width and height are
+ // (x1 - x0 + 1, y1 - y0 + 1).
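+ //
+ // For example (illustrative values), Crop(buffer, 10, 10, 19, 19, output)
+ // produces a 10x10 output.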
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ virtual absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1,
+ int y1, FrameBuffer* output_buffer) = 0;
+
+ // Resizes `buffer` to the size of the given `output_buffer`.
+ //
+ // The resize dimension is determined based on the size of `output_buffer`.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ virtual absl::Status Resize(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) = 0;
+
+ // Rotates `buffer` counter-clockwise by the given `angle_deg` (in degrees).
+ //
+ // When rotating by 90 degrees, the top-right corner of `buffer` becomes
+ // the top-left corner of `output_buffer`. The given angle must be a multiple
+ // of 90 degrees.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ virtual absl::Status Rotate(const FrameBuffer& buffer, int angle_deg,
+ FrameBuffer* output_buffer) = 0;
+
+ // Flips `buffer` horizontally.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ virtual absl::Status FlipHorizontally(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) = 0;
+
+ // Flips `buffer` vertically.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ virtual absl::Status FlipVertically(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) = 0;
+
+ // Converts `buffer`'s format to the format of the given `output_buffer`.
+ //
+ // The `output_buffer` should have metadata populated and its backing buffer
+ // should be big enough to store the operation result.
+ virtual absl::Status Convert(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) = 0;
+};
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_INTERFACE_H_
diff --git a/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc
new file mode 100644
index 00000000..51e72fa1
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc
@@ -0,0 +1,254 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h"
+
+#include "absl/status/status.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+namespace {
+
+using ::absl::StatusCode;
+using ::tflite::ColorSpaceType_RGB;
+using ::tflite::ContentProperties;
+using ::tflite::ContentProperties_ImageProperties;
+using ::tflite::EnumNameContentProperties;
+using ::tflite::ImageProperties;
+using ::tflite::TensorMetadata;
+using ::tflite::metadata::ModelMetadataExtractor;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::task::core::TfLiteEngine;
+
+StatusOr<const TensorMetadata*> GetInputTensorMetadataIfAny(
+ const ModelMetadataExtractor& metadata_extractor) {
+ if (metadata_extractor.GetModelMetadata() == nullptr ||
+ metadata_extractor.GetModelMetadata()->subgraph_metadata() == nullptr) {
+ // Some models have no metadata at all (or very partial), so exit early.
+ return nullptr;
+ } else if (metadata_extractor.GetInputTensorCount() != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Models are assumed to have a single input TensorMetadata.",
+ TfLiteSupportStatus::kInvalidNumInputTensorsError);
+ }
+
+ const TensorMetadata* metadata = metadata_extractor.GetInputTensorMetadata(0);
+
+ if (metadata == nullptr) {
+ // Should never happen.
+ return CreateStatusWithPayload(StatusCode::kInternal,
+ "Input TensorMetadata is null.");
+ }
+
+ return metadata;
+}
+
+StatusOr<const ImageProperties*> GetImagePropertiesIfAny(
+ const TensorMetadata& tensor_metadata) {
+ if (tensor_metadata.content() == nullptr ||
+ tensor_metadata.content()->content_properties() == nullptr) {
+ return nullptr;
+ }
+
+ ContentProperties type = tensor_metadata.content()->content_properties_type();
+
+ if (type != ContentProperties_ImageProperties) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat(
+ "Expected ImageProperties for tensor ",
+ tensor_metadata.name() ? tensor_metadata.name()->str() : "#0",
+ ", got ", EnumNameContentProperties(type), "."),
+ TfLiteSupportStatus::kMetadataInvalidContentPropertiesError);
+ }
+
+ return tensor_metadata.content()->content_properties_as_ImageProperties();
+}
+
+StatusOr<absl::optional<NormalizationOptions>> GetNormalizationOptionsIfAny(
+ const TensorMetadata& tensor_metadata) {
+ ASSIGN_OR_RETURN(
+ const tflite::ProcessUnit* normalization_process_unit,
+ ModelMetadataExtractor::FindFirstProcessUnit(
+ tensor_metadata, tflite::ProcessUnitOptions_NormalizationOptions));
+ if (normalization_process_unit == nullptr) {
+ return {absl::nullopt};
+ }
+ const tflite::NormalizationOptions* tf_normalization_options =
+ normalization_process_unit->options_as_NormalizationOptions();
+ const auto mean_values = tf_normalization_options->mean();
+ const auto std_values = tf_normalization_options->std();
+ if (mean_values->size() != std_values->size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat("NormalizationOptions: expected mean and std of same "
+ "dimension, got ",
+ mean_values->size(), " and ", std_values->size(), "."),
+ TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
+ }
+ absl::optional<NormalizationOptions> normalization_options;
+ if (mean_values->size() == 1) {
+ normalization_options = NormalizationOptions{
+ .mean_values = {mean_values->Get(0), mean_values->Get(0),
+ mean_values->Get(0)},
+ .std_values = {std_values->Get(0), std_values->Get(0),
+ std_values->Get(0)},
+ .num_values = 1};
+ } else if (mean_values->size() == 3) {
+ normalization_options = NormalizationOptions{
+ .mean_values = {mean_values->Get(0), mean_values->Get(1),
+ mean_values->Get(2)},
+ .std_values = {std_values->Get(0), std_values->Get(1),
+ std_values->Get(2)},
+ .num_values = 3};
+ } else {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat("NormalizationOptions: only 1 or 3 mean and std "
+ "values are supported, got ",
+ mean_values->size(), "."),
+ TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
+ }
+ return normalization_options;
+}
+
+} // namespace
+
+StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
+ const TfLiteEngine::Interpreter& interpreter,
+ const tflite::metadata::ModelMetadataExtractor& metadata_extractor) {
+ ASSIGN_OR_RETURN(const TensorMetadata* metadata,
+ GetInputTensorMetadataIfAny(metadata_extractor));
+
+ const ImageProperties* props = nullptr;
+ absl::optional<NormalizationOptions> normalization_options;
+ if (metadata != nullptr) {
+ ASSIGN_OR_RETURN(props, GetImagePropertiesIfAny(*metadata));
+ ASSIGN_OR_RETURN(normalization_options,
+ GetNormalizationOptionsIfAny(*metadata));
+ }
+
+ if (TfLiteEngine::InputCount(&interpreter) != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Models are assumed to have a single input.",
+ TfLiteSupportStatus::kInvalidNumInputTensorsError);
+ }
+
+ // Input-related specifications.
+ const TfLiteTensor* input_tensor = TfLiteEngine::GetInput(&interpreter, 0);
+ if (input_tensor->dims->size != 4) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Only 4D tensors in BHWD layout are supported.",
+ TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
+ }
+ static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteFloat32};
+ TfLiteType input_type = input_tensor->type;
+ if (!absl::c_linear_search(valid_types, input_type)) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat(
+ "Type mismatch for input tensor ", input_tensor->name,
+ ". Requested one of these types: kTfLiteUint8/kTfLiteFloat32, got ",
+ TfLiteTypeGetName(input_type), "."),
+ TfLiteSupportStatus::kInvalidInputTensorTypeError);
+ }
+
+ // The expected layout is BHWD, i.e. batch x height x width x color
+ // See https://www.tensorflow.org/guide/tensors
+ const int batch = input_tensor->dims->data[0];
+ const int height = input_tensor->dims->data[1];
+ const int width = input_tensor->dims->data[2];
+ const int depth = input_tensor->dims->data[3];
+
+ if (props != nullptr && props->color_space() != ColorSpaceType_RGB) {
+ return CreateStatusWithPayload(StatusCode::kInvalidArgument,
+ "Only RGB color space is supported for now.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (batch != 1 || depth != 3) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat("The input tensor should have dimensions 1 x height x "
+ "width x 3. Got ",
+ batch, " x ", height, " x ", width, " x ", depth, "."),
+ TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
+ }
+ int bytes_size = input_tensor->bytes;
+ size_t byte_depth =
+ input_type == kTfLiteFloat32 ? sizeof(float) : sizeof(uint8);
+
+ // Sanity checks.
+ if (input_type == kTfLiteFloat32) {
+ if (!normalization_options.has_value()) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kNotFound,
+ "Input tensor has type kTfLiteFloat32: it requires specifying "
+ "NormalizationOptions metadata to preprocess input images.",
+ TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError);
+ } else if (bytes_size / sizeof(float) %
+ normalization_options.value().num_values !=
+ 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "The number of elements in the input tensor must be a multiple of "
+ "the number of normalization parameters.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ }
+ if (width <= 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument, "The input width should be positive.",
+ TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
+ }
+ if (height <= 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument, "The input height should be positive.",
+ TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
+ }
+ if (bytes_size != height * width * depth * byte_depth) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "The input size in bytes does not correspond to the expected number of "
+ "pixels.",
+ TfLiteSupportStatus::kInvalidInputTensorSizeError);
+ }
+
+ // Note: in the future, additional checks against `props->default_size()`
+ // might be added. Also, verify that NormalizationOptions, if any, do specify
+ // a single value when color space is grayscale.
+
+ ImageTensorSpecs result;
+ result.image_width = width;
+ result.image_height = height;
+ result.color_space = ColorSpaceType_RGB;
+ result.tensor_type = input_type;
+ result.normalization_options = normalization_options;
+
+ return result;
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h
new file mode 100644
index 00000000..536eed4d
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h
@@ -0,0 +1,93 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
+
+#include <array>
+
+#include "absl/types/optional.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Parameters used for input image normalization when input tensor has
+// kTfLiteFloat32 type.
+//
+// Exactly 1 or 3 values are expected for `mean_values` and `std_values`. If
+// only 1 value is specified, it is used for all channels. E.g. for an RGB
+// image, the normalization is done as follows:
+//
+// (R - mean_values[0]) / std_values[0]
+// (G - mean_values[1]) / std_values[1]
+// (B - mean_values[2]) / std_values[2]
+//
+// `num_values` keeps track of how many values have been provided, which should
+// be 1 or 3 (see above). In particular, single-channel grayscale images expect
+// only 1 value.
+struct NormalizationOptions {
+ std::array<float, 3> mean_values;
+ std::array<float, 3> std_values;
+ int num_values;
+};
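+
+// For instance, a common convention for float models is mean_values =
+// {127.5, 127.5, 127.5} and std_values = {127.5, 127.5, 127.5}, which maps
+// uint8 pixel values from [0, 255] into [-1.0, 1.0].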
+
+// Parameters related to the expected tensor specifications when the tensor
+// represents an image.
+//
+// E.g. input tensor specifications expected by the model at Invoke() time. In
+// such a case, and before running inference with the TF Lite interpreter, the
+// caller must use these values and perform image preprocessing and/or
+// normalization so as to fill the actual input tensor appropriately.
+struct ImageTensorSpecs {
+ // Expected image dimensions, e.g. image_width=224, image_height=224.
+ int image_width;
+ int image_height;
+ // Expected color space, e.g. color_space=RGB.
+ tflite::ColorSpaceType color_space;
+ // Expected input tensor type, e.g. if tensor_type=kTfLiteFloat32 the caller
+ // should usually perform some normalization to convert the uint8 pixels into
+ // floats (see NormalizationOptions in TF Lite Metadata for more details).
+ TfLiteType tensor_type;
+ // Optional normalization parameters read from TF Lite Metadata. They are
+ // mandatory when tensor_type=kTfLiteFloat32 in order to convert the input
+ // image data into the expected range of floating point values; an error is
+ // returned otherwise (see the sanity checks in BuildInputImageTensorSpecs).
+ // They should be ignored for other tensor input types, e.g. kTfLiteUInt8.
+ absl::optional<NormalizationOptions> normalization_options;
+};
+
+// Performs sanity checks on the expected input tensor including consistency
+// checks against model metadata, if any. For now, a single RGB input with BHWD
+// layout, where B = 1 and D = 3, is expected. Returns the corresponding input
+// specifications if they pass, or an error otherwise (too many input tensors,
+// etc).
+// Note: both interpreter and metadata extractor *must* be successfully
+// initialized before calling this function by means of (respectively):
+// - `tflite::InterpreterBuilder`,
+// - `tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer`.
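+//
+// Example (an editorial sketch; `interpreter` and `metadata_extractor` are
+// assumed to have been initialized as described above):
+//
+//   ASSIGN_OR_RETURN(ImageTensorSpecs specs,
+//                    BuildInputImageTensorSpecs(interpreter,
+//                                               metadata_extractor));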
+tflite::support::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
+ const tflite::task::core::TfLiteEngine::Interpreter& interpreter,
+ const tflite::metadata::ModelMetadataExtractor& metadata_extractor);
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
diff --git a/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
new file mode 100644
index 00000000..beb58eb4
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
@@ -0,0 +1,1499 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h"
+
+#include <stdint.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "include/libyuv.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+using ::absl::StatusCode;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::TfLiteSupportStatus;
+
+namespace {
+
+// Converts NV12 `buffer` to the `output_buffer` of the target color space.
+// Supported output formats include kRGB, kRGBA, kYV12, kYV21, kNV21 and
+// kGRAY.
+absl::Status ConvertFromNv12(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ switch (output_buffer->format()) {
+ case FrameBuffer::Format::kRGB: {
+ // The RAW format of Libyuv represents the 8-bit interleaved RGB format in
+ // the big endian style with R being the first byte in memory.
+ int ret = libyuv::NV12ToRAW(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer,
+ yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv NV12ToRAW operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kRGBA: {
+ // The libyuv ABGR format is interleaved RGBA format in memory.
+ int ret = libyuv::NV12ToABGR(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer,
+ yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv NV12ToABGR operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ int ret = libyuv::NV12ToI420(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer,
+ yuv_data.uv_row_stride, const_cast<uint8_t*>(output_data.y_buffer),
+ output_data.y_row_stride, const_cast<uint8_t*>(output_data.u_buffer),
+ output_data.uv_row_stride, const_cast<uint8_t*>(output_data.v_buffer),
+ output_data.uv_row_stride, output_buffer->dimension().width,
+ output_buffer->dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv NV12ToI420 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kNV21: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride,
+ const_cast<uint8*>(output_data.y_buffer),
+ output_data.y_row_stride, buffer.dimension().width,
+ buffer.dimension().height);
+ ASSIGN_OR_RETURN(
+ const FrameBuffer::Dimension uv_plane_dimension,
+ GetUvPlaneDimension(buffer.dimension(), buffer.format()));
+ libyuv::SwapUVPlane(yuv_data.u_buffer, yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_data.v_buffer),
+ output_data.uv_row_stride, uv_plane_dimension.width,
+ uv_plane_dimension.height);
+ break;
+ }
+ case FrameBuffer::Format::kGRAY: {
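+      // kGRAY output is simply the luma (Y) plane; the chroma data is
+      // discarded.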
+ libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ output_buffer->dimension().width,
+ output_buffer->dimension().height);
+ break;
+ }
+    default:
+      return CreateStatusWithPayload(
+          StatusCode::kInternal,
+          absl::StrFormat("Format %i is not supported.",
+                          output_buffer->format()),
+          TfLiteSupportStatus::kImageProcessingError);
+ }
+ return absl::OkStatus();
+}
+
+// Converts NV21 `buffer` into the `output_buffer` of the target color space.
+// Supported output formats are kRGB, kRGBA, kNV12, kYV12/kYV21, and kGRAY.
+absl::Status ConvertFromNv21(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ switch (output_buffer->format()) {
+ case FrameBuffer::Format::kRGB: {
+      // Libyuv's RAW format is interleaved RGB in big-endian order, i.e. R is
+      // the first byte in memory.
+ int ret = libyuv::NV21ToRAW(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.v_buffer,
+          yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv NV21ToRAW operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kRGBA: {
+      // Libyuv's ABGR format corresponds to interleaved RGBA bytes in memory.
+ int ret = libyuv::NV21ToABGR(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.v_buffer,
+          yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv NV21ToABGR operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ int ret = libyuv::NV21ToI420(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.v_buffer,
+ yuv_data.uv_row_stride, const_cast<uint8_t*>(output_data.y_buffer),
+ output_data.y_row_stride, const_cast<uint8_t*>(output_data.u_buffer),
+ output_data.uv_row_stride, const_cast<uint8_t*>(output_data.v_buffer),
+ output_data.uv_row_stride, output_buffer->dimension().width,
+ output_buffer->dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv NV21ToI420 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kNV12: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride,
+ const_cast<uint8*>(output_data.y_buffer),
+ output_data.y_row_stride, buffer.dimension().width,
+ buffer.dimension().height);
+ ASSIGN_OR_RETURN(
+ const FrameBuffer::Dimension uv_plane_dimension,
+ GetUvPlaneDimension(buffer.dimension(), buffer.format()));
+ libyuv::SwapUVPlane(yuv_data.v_buffer, yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_data.u_buffer),
+ output_data.uv_row_stride, uv_plane_dimension.width,
+ uv_plane_dimension.height);
+ break;
+ }
+ case FrameBuffer::Format::kGRAY: {
+ libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ output_buffer->dimension().width,
+ output_buffer->dimension().height);
+ break;
+ }
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.",
+ output_buffer->format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ return absl::OkStatus();
+}
+
+// Converts YV12/YV21 `buffer` to the `output_buffer` of the target color
+// space. Supported output formats are kRGB, kRGBA, kNV12, kNV21,
+// kYV12/kYV21, and kGRAY.
+absl::Status ConvertFromYv(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ switch (output_buffer->format()) {
+ case FrameBuffer::Format::kRGB: {
+      // Libyuv's RAW format is interleaved RGB in big-endian order, i.e. R is
+      // the first byte in memory.
+ int ret = libyuv::I420ToRAW(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer,
+ yuv_data.uv_row_stride, yuv_data.v_buffer, yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv I420ToRAW operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kRGBA: {
+      // Libyuv's ABGR format corresponds to interleaved RGBA bytes in memory.
+ int ret = libyuv::I420ToABGR(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer,
+ yuv_data.uv_row_stride, yuv_data.v_buffer, yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv I420ToABGR operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kNV12: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ int ret = libyuv::I420ToNV12(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer,
+ yuv_data.uv_row_stride, yuv_data.v_buffer, yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride,
+ const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride,
+ output_buffer->dimension().width, output_buffer->dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv I420ToNV12 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kNV21: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ int ret = libyuv::I420ToNV21(
+ yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer,
+ yuv_data.uv_row_stride, yuv_data.v_buffer, yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride,
+ const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride,
+ output_buffer->dimension().width, output_buffer->dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv I420ToNV21 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kGRAY: {
+ libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ output_buffer->dimension().width,
+ output_buffer->dimension().height);
+ break;
+ }
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_yuv_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ ASSIGN_OR_RETURN(
+ const FrameBuffer::Dimension uv_plane_dimension,
+ GetUvPlaneDimension(buffer.dimension(), buffer.format()));
+ libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride,
+ const_cast<uint8*>(output_yuv_data.y_buffer),
+ output_yuv_data.y_row_stride, buffer.dimension().width,
+ buffer.dimension().height);
+ libyuv::CopyPlane(yuv_data.u_buffer, yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_yuv_data.u_buffer),
+ output_yuv_data.uv_row_stride, uv_plane_dimension.width,
+ uv_plane_dimension.height);
+ libyuv::CopyPlane(yuv_data.v_buffer, yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_yuv_data.v_buffer),
+ output_yuv_data.uv_row_stride, uv_plane_dimension.width,
+ uv_plane_dimension.height);
+ break;
+ }
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.",
+ output_buffer->format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ return absl::OkStatus();
+}
+
+// Resizes YV12/YV21 `buffer` to the target `output_buffer`.
+absl::Status ResizeYv(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ // TODO(b/151217096): Choose the optimal image resizing filter to optimize
+ // the model inference performance.
+ int ret = libyuv::I420Scale(
+ input_data.y_buffer, input_data.y_row_stride, input_data.u_buffer,
+ input_data.uv_row_stride, input_data.v_buffer, input_data.uv_row_stride,
+ buffer.dimension().width, buffer.dimension().height,
+ const_cast<uint8_t*>(output_data.y_buffer), output_data.y_row_stride,
+ const_cast<uint8_t*>(output_data.u_buffer), output_data.uv_row_stride,
+ const_cast<uint8_t*>(output_data.v_buffer), output_data.uv_row_stride,
+ output_buffer->dimension().width, output_buffer->dimension().height,
+ libyuv::FilterMode::kFilterBilinear);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv I420Scale operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+}
+
+// Resizes NV12/NV21 `buffer` to the target `output_buffer`.
+absl::Status ResizeNv(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
+ const int buffer_size =
+ GetFrameBufferByteSize(buffer.dimension(), FrameBuffer::Format::kYV21);
+ auto yuv_raw_buffer = absl::make_unique<uint8[]>(buffer_size);
+ ASSIGN_OR_RETURN(
+ std::unique_ptr<FrameBuffer> yuv_buffer,
+ CreateFromRawBuffer(yuv_raw_buffer.get(), buffer.dimension(),
+ FrameBuffer::Format::kYV21, buffer.orientation()));
+  // TODO(b/151375918): The current implementation is a workaround: it
+  // converts the input NV12/NV21 buffer to a planar YV21/YV12 buffer, resizes
+  // that buffer, and converts the resized buffer back to the target format.
+  // Consider optimizing this by adding NV12/NV21 resizing support to Libyuv.
+ if (buffer.format() == FrameBuffer::Format::kNV12) {
+ RETURN_IF_ERROR(ConvertFromNv12(buffer, yuv_buffer.get()));
+ } else if (buffer.format() == FrameBuffer::Format::kNV21) {
+ RETURN_IF_ERROR(ConvertFromNv21(buffer, yuv_buffer.get()));
+ } else {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ const int resized_buffer_size = GetFrameBufferByteSize(
+ output_buffer->dimension(), FrameBuffer::Format::kYV12);
+ auto resized_yuv_raw_buffer = absl::make_unique<uint8[]>(resized_buffer_size);
+ ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> resized_yuv_buffer,
+ CreateFromRawBuffer(resized_yuv_raw_buffer.get(),
+ output_buffer->dimension(),
+ FrameBuffer::Format::kYV12,
+ output_buffer->orientation()));
+ RETURN_IF_ERROR(ResizeYv(*yuv_buffer, resized_yuv_buffer.get()));
+
+ RETURN_IF_ERROR(ConvertFromYv(*resized_yuv_buffer, output_buffer));
+ return absl::OkStatus();
+}
+
+// Converts `buffer` to the libyuv ARGB (BGRA in memory) format and stores the
+// conversion result in `dest_argb`.
+absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb,
+ int dest_stride_argb) {
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
+ if (buffer.format() != FrameBuffer::Format::kRGB) {
+ return CreateStatusWithPayload(StatusCode::kInternal,
+ "RGB input format is expected.",
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ if (dest_argb == nullptr || dest_stride_argb <= 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ "Invalid destination arguments for ConvertRgbToArgb.",
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ int ret = libyuv::RGB24ToARGB(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ dest_argb, dest_stride_argb, buffer.dimension().width,
+ buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv RGB24ToARGB operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+}
+
+// Converts `src_argb` in libyuv ARGB format to FrameBuffer::kRGB format and
+// stores the conversion result in `output_buffer`.
+absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb,
+ FrameBuffer* output_buffer) {
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
+ if (output_buffer->format() != FrameBuffer::Format::kRGB) {
+    return CreateStatusWithPayload(StatusCode::kInternal,
+                                   "RGB input format is expected.",
+                                   TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ if (src_argb == nullptr || src_stride_argb <= 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal, "Invalid source arguments for ConvertArgbToRgb.",
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ if (output_buffer->plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ output_buffer->format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ int ret = libyuv::ARGBToRGB24(
+ src_argb, src_stride_argb,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ output_buffer->dimension().width, output_buffer->dimension().height);
+
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ARGBToRGB24 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+}
+
+// Converts `buffer` in FrameBuffer::kRGBA format to libyuv ARGB (BGRA in
+// memory) format and stores the conversion result in `dest_argb`.
+absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer, uint8* dest_argb,
+ int dest_stride_argb) {
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
+ if (buffer.format() != FrameBuffer::Format::kRGBA) {
+    return CreateStatusWithPayload(
+        StatusCode::kInternal, "RGBA input format is expected.",
+        TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ if (dest_argb == nullptr || dest_stride_argb <= 0) {
+    return CreateStatusWithPayload(
+        StatusCode::kInternal,
+        "Invalid destination arguments for ConvertRgbaToArgb.",
+        TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ int ret = libyuv::ABGRToARGB(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ dest_argb, dest_stride_argb, buffer.dimension().width,
+ buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+        StatusCode::kUnknown, "Libyuv ABGRToARGB operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+}
+
+// Converts kRGB `buffer` to the `output_buffer` of the target color space.
+absl::Status ConvertFromRgb(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ if (output_buffer->format() == FrameBuffer::Format::kGRAY) {
+ int ret = libyuv::RAWToJ400(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+          StatusCode::kUnknown, "Libyuv RAWToJ400 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+ } else if (output_buffer->format() == FrameBuffer::Format::kYV12 ||
+ output_buffer->format() == FrameBuffer::Format::kYV21 ||
+ output_buffer->format() == FrameBuffer::Format::kNV12 ||
+ output_buffer->format() == FrameBuffer::Format::kNV21) {
+    // libyuv does not support converting kRGB directly to kNV12 / kNV21. For
+    // those formats, the implementation first converts kRGB to I420 and then
+    // converts the I420 buffer to kNV12 / kNV21.
+ // TODO(b/153000936): use libyuv::RawToNV12 / libyuv::RawToNV21 when they
+ // are ready.
+ FrameBuffer::YuvData yuv_data;
+ std::unique_ptr<uint8[]> tmp_yuv_buffer;
+ std::unique_ptr<FrameBuffer> yuv_frame_buffer;
+ if (output_buffer->format() == FrameBuffer::Format::kNV12 ||
+ output_buffer->format() == FrameBuffer::Format::kNV21) {
+ tmp_yuv_buffer = absl::make_unique<uint8[]>(
+ GetFrameBufferByteSize(buffer.dimension(), output_buffer->format()));
+ ASSIGN_OR_RETURN(
+ yuv_frame_buffer,
+ CreateFromRawBuffer(tmp_yuv_buffer.get(), buffer.dimension(),
+ FrameBuffer::Format::kYV21,
+ output_buffer->orientation()));
+ ASSIGN_OR_RETURN(
+ yuv_data, FrameBuffer::GetYuvDataFromFrameBuffer(*yuv_frame_buffer));
+ } else {
+ ASSIGN_OR_RETURN(yuv_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ }
+ int ret = libyuv::RAWToI420(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(yuv_data.y_buffer), yuv_data.y_row_stride,
+ const_cast<uint8*>(yuv_data.u_buffer), yuv_data.uv_row_stride,
+ const_cast<uint8*>(yuv_data.v_buffer), yuv_data.uv_row_stride,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+          StatusCode::kUnknown, "Libyuv RAWToI420 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ if (output_buffer->format() == FrameBuffer::Format::kNV12 ||
+ output_buffer->format() == FrameBuffer::Format::kNV21) {
+ return ConvertFromYv(*yuv_frame_buffer, output_buffer);
+ }
+ return absl::OkStatus();
+ }
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", output_buffer->format()),
+ TfLiteSupportStatus::kImageProcessingError);
+}
+
+// Converts kRGBA `buffer` to the `output_buffer` of the target color space.
+absl::Status ConvertFromRgba(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ switch (output_buffer->format()) {
+ case FrameBuffer::Format::kGRAY: {
+      // libyuv does not support converting the kRGBA (ABGR) format. The
+      // implementation first converts kRGBA to ARGB and then uses the ARGB
+      // buffer for the conversion.
+ // TODO(b/141181395): Use libyuv::ABGRToJ400 when it is ready.
+
+ // Convert kRGBA to ARGB
+ int argb_buffer_size = GetFrameBufferByteSize(buffer.dimension(),
+ FrameBuffer::Format::kRGBA);
+ auto argb_buffer = absl::make_unique<uint8[]>(argb_buffer_size);
+ const int argb_row_bytes = buffer.dimension().width * kRgbaPixelBytes;
+ RETURN_IF_ERROR(
+ ConvertRgbaToArgb(buffer, argb_buffer.get(), argb_row_bytes));
+
+ // Convert ARGB to kGRAY
+ int ret = libyuv::ARGBToJ400(
+ argb_buffer.get(), argb_row_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ARGBToJ400 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kNV12: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ int ret = libyuv::ABGRToNV12(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride,
+ const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ABGRToNV12 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kNV21: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ int ret = libyuv::ABGRToNV21(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride,
+ const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ABGRToNV21 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21: {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ int ret = libyuv::ABGRToI420(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride,
+ const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride,
+ const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ABGRToI420 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ case FrameBuffer::Format::kRGB: {
+ // ARGB is BGRA in memory and RGB24 is BGR in memory. The removal of the
+ // alpha channel will not impact the RGB ordering.
+ int ret = libyuv::ARGBToRGB24(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+            StatusCode::kUnknown, "Libyuv ARGBToRGB24 operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ break;
+ }
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Convert Rgba to format %i is not supported.",
+ output_buffer->format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ return absl::OkStatus();
+}
+
+// Returns the libyuv rotation mode for a counter-clockwise rotation of
+// `angle_deg` degrees. Libyuv rotation modes are expressed clockwise, so e.g.
+// a 90-degree counter-clockwise rotation maps to libyuv::kRotate270.
+libyuv::RotationMode GetLibyuvRotationMode(int angle_deg) {
+ switch (angle_deg) {
+ case 90:
+ return libyuv::kRotate270;
+ case 270:
+ return libyuv::kRotate90;
+ case 180:
+ return libyuv::kRotate180;
+ default:
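+      // Angles that are not multiples of 90 fall through to no rotation;
+      // callers are expected to validate `angle_deg` beforehand.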
+ return libyuv::kRotate0;
+ }
+}
+
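+// Rotates kRGBA `buffer` counter-clockwise by `angle_deg` degrees and stores
+// the result in `output_buffer`.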
+absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg,
+ FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ // libyuv::ARGBRotate assumes RGBA buffer is in the interleaved format.
+ int ret = libyuv::ARGBRotate(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes, buffer.dimension().width,
+ buffer.dimension().height, GetLibyuvRotationMode(angle_deg % 360));
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ARGBRotate operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+}
+
+absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg,
+ FrameBuffer* output_buffer) {
+  // libyuv does not support rotating the kRGB (RGB24) format. The
+  // implementation converts kRGB to ARGB, rotates the ARGB buffer, and then
+  // converts the result back to RGB.
+
+ // Convert RGB to ARGB
+ int argb_buffer_size =
+ GetFrameBufferByteSize(buffer.dimension(), FrameBuffer::Format::kRGBA);
+ auto argb_buffer = absl::make_unique<uint8[]>(argb_buffer_size);
+ const int argb_row_bytes = buffer.dimension().width * kRgbaPixelBytes;
+ RETURN_IF_ERROR(ConvertRgbToArgb(buffer, argb_buffer.get(), argb_row_bytes));
+
+ // Rotate ARGB
+ auto argb_rotated_buffer = absl::make_unique<uint8[]>(argb_buffer_size);
+ int rotated_row_bytes = output_buffer->dimension().width * kRgbaPixelBytes;
+ // TODO(b/151954340): Optimize the current implementation by utilizing
+ // ARGBMirror for 180 degree rotation.
+ int ret = libyuv::ARGBRotate(
+ argb_buffer.get(), argb_row_bytes, argb_rotated_buffer.get(),
+ rotated_row_bytes, buffer.dimension().width, buffer.dimension().height,
+ GetLibyuvRotationMode(angle_deg % 360));
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ARGBRotate operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+
+ // Convert ARGB to RGB
+ return ConvertArgbToRgb(argb_rotated_buffer.get(), rotated_row_bytes,
+ output_buffer);
+}
+
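+// Rotates single-plane kGRAY `buffer` counter-clockwise by `angle_deg`
+// degrees and stores the result in `output_buffer`.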
+absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg,
+ FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ int ret = libyuv::RotatePlane(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes, buffer.dimension().width,
+ buffer.dimension().height, GetLibyuvRotationMode(angle_deg % 360));
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv RotatePlane operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+}
+
+// Rotates YV12/YV21 frame buffer.
+absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ int ret = libyuv::I420Rotate(
+ input_data.y_buffer, input_data.y_row_stride, input_data.u_buffer,
+ input_data.uv_row_stride, input_data.v_buffer, input_data.uv_row_stride,
+ const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride,
+ const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride,
+ const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride,
+ buffer.dimension().width, buffer.dimension().height,
+ GetLibyuvRotationMode(angle_deg));
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv I420Rotate operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+}
+
+// Rotates NV12/NV21 frame buffer.
+// TODO(b/152097364): Refactor NV12/NV21 rotation after libyuv explicitly
+// support that.
+absl::Status RotateNv(const FrameBuffer& buffer, int angle_deg,
+ FrameBuffer* output_buffer) {
+ if (buffer.format() != FrameBuffer::Format::kNV12 &&
+ buffer.format() != FrameBuffer::Format::kNV21) {
+ return CreateStatusWithPayload(StatusCode::kInternal,
+ "kNV12 or kNV21 input formats are expected.",
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ const int rotated_buffer_size = GetFrameBufferByteSize(
+ output_buffer->dimension(), FrameBuffer::Format::kYV21);
+ auto rotated_yuv_raw_buffer = absl::make_unique<uint8[]>(rotated_buffer_size);
+ ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> rotated_yuv_buffer,
+ CreateFromRawBuffer(
+ rotated_yuv_raw_buffer.get(), output_buffer->dimension(),
+ /*target_format=*/FrameBuffer::Format::kYV21,
+ output_buffer->orientation()));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData rotated_yuv_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*rotated_yuv_buffer));
+  // Get the first chroma plane and use it as the u plane. This is a
+  // workaround for optimizing NV21 rotation. For NV12, this is logically
+  // correct. For NV21, using the v plane as the u plane swaps the UV planes
+  // in the intermediate rotated I420 frame. The output buffer is then built
+  // by merging the swapped UV planes, which produces a V-first interleaved
+  // UV buffer.
+ const uint8* chroma_buffer = buffer.format() == FrameBuffer::Format::kNV12
+ ? input_data.u_buffer
+ : input_data.v_buffer;
+ // Rotate the Y plane and store into the Y plane in `output_buffer`. Rotate
+ // the interleaved UV plane and store into the interleaved UV plane in
+ // `rotated_yuv_buffer`.
+ int ret = libyuv::NV12ToI420Rotate(
+ input_data.y_buffer, input_data.y_row_stride, chroma_buffer,
+ input_data.uv_row_stride, const_cast<uint8*>(output_data.y_buffer),
+ output_data.y_row_stride, const_cast<uint8*>(rotated_yuv_data.u_buffer),
+ rotated_yuv_data.uv_row_stride,
+ const_cast<uint8*>(rotated_yuv_data.v_buffer),
+ rotated_yuv_data.uv_row_stride, buffer.dimension().width,
+ buffer.dimension().height, GetLibyuvRotationMode(angle_deg % 360));
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv Nv12ToI420Rotate operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+  // Merge the rotated UV planes into the output buffer. For NV21, the UV
+  // planes of the intermediate I420 frame are swapped, so placing that
+  // frame's U plane (actually the input's V plane) first produces the
+  // V-first interleaved buffer that NV21 expects.
+ const uint8* output_chroma_buffer =
+ buffer.format() == FrameBuffer::Format::kNV12 ? output_data.u_buffer
+ : output_data.v_buffer;
+ // The width and height arguments of `libyuv::MergeUVPlane()` represent the
+ // width and height of the UV planes.
+ libyuv::MergeUVPlane(
+ rotated_yuv_data.u_buffer, rotated_yuv_data.uv_row_stride,
+ rotated_yuv_data.v_buffer, rotated_yuv_data.uv_row_stride,
+ const_cast<uint8*>(output_chroma_buffer), output_data.uv_row_stride,
+ (output_buffer->dimension().width + 1) / 2,
+ (output_buffer->dimension().height + 1) / 2);
+ return absl::OkStatus();
+}
+
+// This method only supports kGRAY, kRGB, and kRGBA formats.
+absl::Status FlipPlaneVertically(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ ASSIGN_OR_RETURN(int pixel_stride, GetPixelStrides(buffer.format()));
+
+  // Flipping vertically is achieved by passing in a negative height.
+ libyuv::CopyPlane(buffer.plane(0).buffer,
+ buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ output_buffer->dimension().width * pixel_stride,
+ -output_buffer->dimension().height);
+
+ return absl::OkStatus();
+}
+
+// This method only supports kGRAY, kRGBA, and kRGB formats.
+absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1,
+ int y1, FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ ASSIGN_OR_RETURN(int pixel_stride, GetPixelStrides(buffer.format()));
+ FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
+
+  // Cropping is achieved by adjusting the origin to (x0, y0).
+ int adjusted_offset =
+ buffer.plane(0).stride.row_stride_bytes * y0 + x0 * pixel_stride;
+
+ libyuv::CopyPlane(buffer.plane(0).buffer + adjusted_offset,
+ buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ crop_dimension.width * pixel_stride, crop_dimension.height);
+
+ return absl::OkStatus();
+}
+
+// Crops NV12/NV21 FrameBuffer to the subregion defined by the top left pixel
+// position (x0, y0) and the bottom right pixel position (x1, y1).
+absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ // Crop Y plane by copying the buffer with the origin offset to (x0, y0).
+ int crop_offset_y = input_data.y_row_stride * y0 + x0;
+ int crop_width = x1 - x0 + 1;
+ int crop_height = y1 - y0 + 1;
+ libyuv::CopyPlane(input_data.y_buffer + crop_offset_y,
+ input_data.y_row_stride,
+ const_cast<uint8*>(output_data.y_buffer),
+ output_data.y_row_stride, crop_width, crop_height);
+ // Crop chroma plane by copying the buffer with the origin offset to
+  // (x0 / 2, y0 / 2).
+ // TODO(b/152629712): Investigate the impact of color shifting caused by the
+ // bounding box with odd X or Y starting positions.
+ int crop_offset_chroma = input_data.uv_row_stride * (y0 / 2) +
+ input_data.uv_pixel_stride * (x0 / 2);
+ ASSIGN_OR_RETURN(const uint8* input_chroma_buffer, GetUvRawBuffer(buffer));
+ ASSIGN_OR_RETURN(const uint8* output_chroma_buffer,
+ GetUvRawBuffer(*output_buffer));
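+  // The interleaved chroma plane stores one UV byte pair per 2x2 block of
+  // luma pixels, so the copy width in bytes is twice the rounded-up half
+  // width, and the copy height is the rounded-up half height.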
+ libyuv::CopyPlane(
+ input_chroma_buffer + crop_offset_chroma, input_data.uv_row_stride,
+ const_cast<uint8*>(output_chroma_buffer), output_data.uv_row_stride,
+ /*width=*/(crop_width + 1) / 2 * 2, /*height=*/(crop_height + 1) / 2);
+ return absl::OkStatus();
+}
+
+// Crops YV12/YV21 FrameBuffer to the subregion defined by the top left pixel
+// position (x0, y0) and the bottom right pixel position (x1, y1).
+absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ // Crop Y plane by copying the buffer with the origin offset to (x0, y0).
+ int crop_offset_y = input_data.y_row_stride * y0 + x0;
+ FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
+ libyuv::CopyPlane(
+ input_data.y_buffer + crop_offset_y, input_data.y_row_stride,
+ const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride,
+ crop_dimension.width, crop_dimension.height);
+ // Crop U plane by copying the buffer with the origin offset to
+ // (x0 / 2, y0 / 2).
+ ASSIGN_OR_RETURN(const FrameBuffer::Dimension crop_uv_dimension,
+ GetUvPlaneDimension(crop_dimension, buffer.format()));
+ // TODO(b/152629712): Investigate the impact of color shifting caused by the
+ // bounding box with odd X or Y starting positions.
+ int crop_offset_chroma = input_data.uv_row_stride * (y0 / 2) +
+ input_data.uv_pixel_stride * (x0 / 2);
+ libyuv::CopyPlane(
+ input_data.u_buffer + crop_offset_chroma, input_data.uv_row_stride,
+ const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride,
+ crop_uv_dimension.width, crop_uv_dimension.height);
+ // Crop V plane by copying the buffer with the origin offset to
+  // (x0 / 2, y0 / 2).
+ libyuv::CopyPlane(
+ input_data.v_buffer + crop_offset_chroma, input_data.uv_row_stride,
+ const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride,
+ /*width=*/(crop_dimension.width + 1) / 2,
+ /*height=*/(crop_dimension.height + 1) / 2);
+ return absl::OkStatus();
+}
+
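+// Crops NV12/NV21/YV12/YV21 `buffer` to the region defined by (x0, y0) and
+// (x1, y1), then resizes the cropped region to the dimensions of
+// `output_buffer`.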
+absl::Status CropResizeYuv(const FrameBuffer& buffer, int x0, int y0, int x1,
+ int y1, FrameBuffer* output_buffer) {
+ FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
+ if (crop_dimension == output_buffer->dimension()) {
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ return CropNv(buffer, x0, y0, x1, y1, output_buffer);
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return CropYv(buffer, x0, y0, x1, y1, output_buffer);
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ }
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+  // Crop the YUV planes by offsetting the origin of each plane.
+ // TODO(b/152629712): Investigate the impact of color shifting caused by the
+ // bounding box with odd X or Y starting positions.
+ const int plane_y_offset = input_data.y_row_stride * y0 + x0;
+ const int plane_uv_offset = input_data.uv_row_stride * (y0 / 2) +
+ input_data.uv_pixel_stride * (x0 / 2);
+ FrameBuffer::Plane cropped_plane_y = {
+ /*buffer=*/input_data.y_buffer + plane_y_offset,
+ /*stride=*/{input_data.y_row_stride, /*pixel_stride_bytes=*/1}};
+ FrameBuffer::Plane cropped_plane_u = {
+ /*buffer=*/input_data.u_buffer + plane_uv_offset,
+ /*stride=*/{input_data.uv_row_stride, input_data.uv_pixel_stride}};
+ FrameBuffer::Plane cropped_plane_v = {
+ /*buffer=*/input_data.v_buffer + plane_uv_offset,
+ /*stride=*/{input_data.uv_row_stride, input_data.uv_pixel_stride}};
+
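+  // Re-wrap the cropped planes using each format's native plane ordering:
+  // NV21 and YV12 store V before U, while NV12 and YV21 store U before V.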
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kNV12: {
+ std::unique_ptr<FrameBuffer> cropped_buffer = FrameBuffer::Create(
+ {cropped_plane_y, cropped_plane_u, cropped_plane_v}, crop_dimension,
+ buffer.format(), buffer.orientation());
+ return ResizeNv(*cropped_buffer, output_buffer);
+ }
+ case FrameBuffer::Format::kNV21: {
+ std::unique_ptr<FrameBuffer> cropped_buffer = FrameBuffer::Create(
+ {cropped_plane_y, cropped_plane_v, cropped_plane_u}, crop_dimension,
+ buffer.format(), buffer.orientation());
+ return ResizeNv(*cropped_buffer, output_buffer);
+ }
+ case FrameBuffer::Format::kYV12: {
+ std::unique_ptr<FrameBuffer> cropped_buffer = FrameBuffer::Create(
+ {cropped_plane_y, cropped_plane_v, cropped_plane_u}, crop_dimension,
+ buffer.format(), buffer.orientation());
+ return ResizeYv(*cropped_buffer, output_buffer);
+ }
+ case FrameBuffer::Format::kYV21: {
+ std::unique_ptr<FrameBuffer> cropped_buffer = FrameBuffer::Create(
+ {cropped_plane_y, cropped_plane_u, cropped_plane_v}, crop_dimension,
+ buffer.format(), buffer.orientation());
+ return ResizeYv(*cropped_buffer, output_buffer);
+ }
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ return absl::OkStatus();
+}
+
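+// Flips kRGBA `buffer` horizontally and stores the result in `output_buffer`.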
+absl::Status FlipHorizontallyRgba(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+ int ret = libyuv::ARGBMirror(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ output_buffer->dimension().width, output_buffer->dimension().height);
+
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ARGBMirror operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+
+ return absl::OkStatus();
+}
+
+// Flips `buffer` horizontally and stores the result in `output_buffer`. This
+// method assumes all buffers have a pixel stride equal to 1.
+absl::Status FlipHorizontallyPlane(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ libyuv::MirrorPlane(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ output_buffer->dimension().width, output_buffer->dimension().height);
+
+ return absl::OkStatus();
+}
+
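+// Resizes kRGB `buffer` to the dimensions of `output_buffer`.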
+absl::Status ResizeRgb(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+  // libyuv does not support scaling the kRGB (RGB24) format. The
+  // implementation converts kRGB to ARGB, scales the ARGB buffer, and then
+  // converts the result back to RGB.
+
+ // Convert RGB to ARGB
+ int argb_buffer_size =
+ GetFrameBufferByteSize(buffer.dimension(), FrameBuffer::Format::kRGBA);
+ auto argb_buffer = absl::make_unique<uint8[]>(argb_buffer_size);
+ const int argb_row_bytes = buffer.dimension().width * kRgbaPixelBytes;
+ RETURN_IF_ERROR(ConvertRgbToArgb(buffer, argb_buffer.get(), argb_row_bytes));
+
+ // Resize ARGB
+ int resized_argb_buffer_size = GetFrameBufferByteSize(
+ output_buffer->dimension(), FrameBuffer::Format::kRGBA);
+ auto resized_argb_buffer =
+ absl::make_unique<uint8[]>(resized_argb_buffer_size);
+ int resized_argb_row_bytes =
+ output_buffer->dimension().width * kRgbaPixelBytes;
+ int ret = libyuv::ARGBScale(
+ argb_buffer.get(), argb_row_bytes, buffer.dimension().width,
+ buffer.dimension().height, resized_argb_buffer.get(),
+ resized_argb_row_bytes, output_buffer->dimension().width,
+ output_buffer->dimension().height, libyuv::FilterMode::kFilterBilinear);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ARGBScale operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+
+ // Convert ARGB to RGB
+ return ConvertArgbToRgb(resized_argb_buffer.get(), resized_argb_row_bytes,
+ output_buffer);
+}
+
+// Flips kRGB `buffer` horizontally and stores the result in `output_buffer`.
+absl::Status FlipHorizontallyRgb(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+
+#if LIBYUV_VERSION >= 1747
+ int ret = libyuv::RGB24Mirror(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes, buffer.dimension().width,
+ buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv RGB24Mirror operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+
+ return absl::OkStatus();
+#else
+#error LibyuvFrameBufferUtils requires LIBYUV_VERSION 1747 or above
+#endif // LIBYUV_VERSION >= 1747
+}
+
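+// Resizes kRGBA `buffer` to the dimensions of `output_buffer`.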
+absl::Status ResizeRgba(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ int ret = libyuv::ARGBScale(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ output_buffer->dimension().width, output_buffer->dimension().height,
+ libyuv::FilterMode::kFilterBilinear);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv ARGBScale operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+}
+
+// Flips NV12/NV21 FrameBuffer horizontally.
+absl::Status FlipHorizontallyNv(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ ASSIGN_OR_RETURN(const uint8* input_chroma_buffer, GetUvRawBuffer(buffer));
+ ASSIGN_OR_RETURN(const uint8* output_chroma_buffer,
+ GetUvRawBuffer(*output_buffer));
+
+ int ret = libyuv::NV12Mirror(
+ input_data.y_buffer, input_data.y_row_stride, input_chroma_buffer,
+ input_data.uv_row_stride, const_cast<uint8*>(output_data.y_buffer),
+ output_data.y_row_stride, const_cast<uint8*>(output_chroma_buffer),
+ output_data.uv_row_stride, buffer.dimension().width,
+ buffer.dimension().height);
+
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv NV12Mirror operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+
+ return absl::OkStatus();
+}
+
+// Flips YV12/YV21 FrameBuffer horizontally.
+absl::Status FlipHorizontallyYv(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ int ret = libyuv::I420Mirror(
+ input_data.y_buffer, input_data.y_row_stride, input_data.u_buffer,
+ input_data.uv_row_stride, input_data.v_buffer, input_data.uv_row_stride,
+ const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride,
+ const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride,
+ const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride,
+ buffer.dimension().width, buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv I420Mirror operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+
+ return absl::OkStatus();
+}
+
+// Flips NV12/NV21 FrameBuffer vertically.
+absl::Status FlipVerticallyNv(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ // Flip Y plane vertically by passing a negative height.
+ libyuv::CopyPlane(input_data.y_buffer, input_data.y_row_stride,
+ const_cast<uint8*>(output_data.y_buffer),
+ output_data.y_row_stride, buffer.dimension().width,
+ -output_buffer->dimension().height);
+ // Flip UV plane vertically by passing a negative height.
+ ASSIGN_OR_RETURN(const uint8* input_chroma_buffer, GetUvRawBuffer(buffer));
+ ASSIGN_OR_RETURN(const uint8* output_chroma_buffer,
+ GetUvRawBuffer(*output_buffer));
+ ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_plane_dimension,
+ GetUvPlaneDimension(buffer.dimension(), buffer.format()));
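+  // The chroma rows hold interleaved UV byte pairs, so the copy width in
+  // bytes is twice the UV plane width.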
+ libyuv::CopyPlane(
+ input_chroma_buffer, input_data.uv_row_stride,
+ const_cast<uint8*>(output_chroma_buffer), output_data.uv_row_stride,
+ /*width=*/uv_plane_dimension.width * 2, -uv_plane_dimension.height);
+ return absl::OkStatus();
+}
+
+// Flips YV12/YV21 FrameBuffer vertically.
+absl::Status FlipVerticallyYv(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+ ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
+ FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
+ // Flip buffer vertically by passing a negative height.
+ int ret = libyuv::I420Copy(
+ input_data.y_buffer, input_data.y_row_stride, input_data.u_buffer,
+ input_data.uv_row_stride, input_data.v_buffer, input_data.uv_row_stride,
+ const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride,
+ const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride,
+ const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride,
+ buffer.dimension().width, -buffer.dimension().height);
+ if (ret != 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown, "Libyuv I420Copy operation failed.",
+ TfLiteSupportStatus::kImageProcessingBackendError);
+ }
+ return absl::OkStatus();
+}
+
+// Resizes `buffer` to the dimensions defined in `output_buffer`. This method
+// assumes the buffer has a pixel stride equal to 1 (i.e. grayscale).
+absl::Status ResizeGray(const FrameBuffer& buffer, FrameBuffer* output_buffer) {
+ if (buffer.plane_count() > 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Only single plane is supported for format %i.",
+ buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+ libyuv::ScalePlane(
+ buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes,
+ buffer.dimension().width, buffer.dimension().height,
+ const_cast<uint8*>(output_buffer->plane(0).buffer),
+ output_buffer->plane(0).stride.row_stride_bytes,
+ output_buffer->dimension().width, output_buffer->dimension().height,
+ libyuv::FilterMode::kFilterBilinear);
+ return absl::OkStatus();
+}
+
+// This method only supports kGRAY, kRGBA, and kRGB formats.
+absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1,
+ int y1, FrameBuffer* output_buffer) {
+ FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1);
+ if (crop_dimension == output_buffer->dimension()) {
+ return CropPlane(buffer, x0, y0, x1, y1, output_buffer);
+ }
+
+ ASSIGN_OR_RETURN(int pixel_stride, GetPixelStrides(buffer.format()));
+  // Cropping is achieved by adjusting the origin to (x0, y0).
+ int adjusted_offset =
+ buffer.plane(0).stride.row_stride_bytes * y0 + x0 * pixel_stride;
+ FrameBuffer::Plane plane = {
+ /*buffer=*/buffer.plane(0).buffer + adjusted_offset,
+ /*stride=*/{buffer.plane(0).stride.row_stride_bytes, pixel_stride}};
+ auto adjusted_buffer =
+ FrameBuffer::Create({plane}, crop_dimension, buffer.format(),
+ buffer.orientation(), buffer.timestamp());
+
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kRGB:
+ return ResizeRgb(*adjusted_buffer, output_buffer);
+ case FrameBuffer::Format::kRGBA:
+ return ResizeRgba(*adjusted_buffer, output_buffer);
+ case FrameBuffer::Format::kGRAY:
+ return ResizeGray(*adjusted_buffer, output_buffer);
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+}
+} // namespace
+
+absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer, int x0,
+ int y0, int x1, int y1,
+ FrameBuffer* output_buffer) {
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
+ RETURN_IF_ERROR(
+ ValidateCropBufferInputs(buffer, *output_buffer, x0, y0, x1, y1));
+ RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer));
+
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kRGBA:
+ case FrameBuffer::Format::kRGB:
+ case FrameBuffer::Format::kGRAY:
+ return CropResize(buffer, x0, y0, x1, y1, output_buffer);
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return CropResizeYuv(buffer, x0, y0, x1, y1, output_buffer);
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+}
+
+absl::Status LibyuvFrameBufferUtils::Resize(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ RETURN_IF_ERROR(ValidateResizeBufferInputs(buffer, *output_buffer));
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return ResizeYv(buffer, output_buffer);
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ return ResizeNv(buffer, output_buffer);
+ case FrameBuffer::Format::kRGB:
+ return ResizeRgb(buffer, output_buffer);
+ case FrameBuffer::Format::kRGBA:
+ return ResizeRgba(buffer, output_buffer);
+ case FrameBuffer::Format::kGRAY:
+ return ResizeGray(buffer, output_buffer);
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+}
+
+absl::Status LibyuvFrameBufferUtils::Rotate(const FrameBuffer& buffer,
+ int angle_deg,
+ FrameBuffer* output_buffer) {
+ RETURN_IF_ERROR(
+ ValidateRotateBufferInputs(buffer, *output_buffer, angle_deg));
+ RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer));
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
+
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kGRAY:
+ return RotateGray(buffer, angle_deg, output_buffer);
+ case FrameBuffer::Format::kRGBA:
+ return RotateRgba(buffer, angle_deg, output_buffer);
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ return RotateNv(buffer, angle_deg, output_buffer);
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return RotateYv(buffer, angle_deg, output_buffer);
+ case FrameBuffer::Format::kRGB:
+ return RotateRgb(buffer, angle_deg, output_buffer);
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+}
+
+absl::Status LibyuvFrameBufferUtils::FlipHorizontally(
+ const FrameBuffer& buffer, FrameBuffer* output_buffer) {
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
+ RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
+ RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer));
+
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kRGBA:
+ return FlipHorizontallyRgba(buffer, output_buffer);
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return FlipHorizontallyYv(buffer, output_buffer);
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ return FlipHorizontallyNv(buffer, output_buffer);
+ case FrameBuffer::Format::kRGB:
+ return FlipHorizontallyRgb(buffer, output_buffer);
+ case FrameBuffer::Format::kGRAY:
+ return FlipHorizontallyPlane(buffer, output_buffer);
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+}
+
+absl::Status LibyuvFrameBufferUtils::FlipVertically(
+ const FrameBuffer& buffer, FrameBuffer* output_buffer) {
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
+ RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
+ RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
+ RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer));
+
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kRGBA:
+ case FrameBuffer::Format::kRGB:
+ case FrameBuffer::Format::kGRAY:
+ return FlipPlaneVertically(buffer, output_buffer);
+ case FrameBuffer::Format::kNV12:
+ case FrameBuffer::Format::kNV21:
+ return FlipVerticallyNv(buffer, output_buffer);
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return FlipVerticallyYv(buffer, output_buffer);
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+}
+
+absl::Status LibyuvFrameBufferUtils::Convert(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) {
+ RETURN_IF_ERROR(
+ ValidateConvertFormats(buffer.format(), output_buffer->format()));
+ switch (buffer.format()) {
+ case FrameBuffer::Format::kNV12:
+ return ConvertFromNv12(buffer, output_buffer);
+ case FrameBuffer::Format::kNV21:
+ return ConvertFromNv21(buffer, output_buffer);
+ case FrameBuffer::Format::kYV12:
+ case FrameBuffer::Format::kYV21:
+ return ConvertFromYv(buffer, output_buffer);
+ case FrameBuffer::Format::kRGB:
+ return ConvertFromRgb(buffer, output_buffer);
+ case FrameBuffer::Format::kRGBA:
+ return ConvertFromRgba(buffer, output_buffer);
+ default:
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Format %i is not supported.", buffer.format()),
+ TfLiteSupportStatus::kImageProcessingError);
+ }
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h b/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
new file mode 100644
index 00000000..0d001c8c
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h
@@ -0,0 +1,76 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_LIBYUV_FRAME_BUFFER_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_LIBYUV_FRAME_BUFFER_UTILS_H_
+
+#include "absl/status/status.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Libyuv-based image processing engine that conforms to
+// FrameBufferUtilsInterface. Although this class exposes public APIs, prefer
+// the higher-level APIs defined in frame_buffer_utils.h, which offer better
+// abstraction and broader functionality support.
+class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface {
+ public:
+ LibyuvFrameBufferUtils() = default;
+ ~LibyuvFrameBufferUtils() override = default;
+
+  // Crops the input `buffer` to the specified subregion and resizes the
+  // cropped region to the target resolution defined by the `output_buffer`.
+ //
+ // (x0, y0) represents the top-left point of the buffer.
+ // (x1, y1) represents the bottom-right point of the buffer.
+ //
+  // The crop region dimensions must be equal to or smaller than the input
+  // `buffer` dimensions.
+ absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
+ FrameBuffer* output_buffer) override;
+
+ // Resizes `buffer` to the size of the given `output_buffer`.
+ absl::Status Resize(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) override;
+
+ // Rotates `buffer` counter-clockwise by the given `angle_deg` (in degrees).
+ //
+ // The given angle must be a multiple of 90 degrees.
+ absl::Status Rotate(const FrameBuffer& buffer, int angle_deg,
+ FrameBuffer* output_buffer) override;
+
+ // Flips `buffer` horizontally.
+ absl::Status FlipHorizontally(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) override;
+
+ // Flips `buffer` vertically.
+ absl::Status FlipVertically(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) override;
+
+ // Converts `buffer`'s format to the format of the given `output_buffer`.
+ //
+ // Grayscale format cannot be converted to other formats.
+ absl::Status Convert(const FrameBuffer& buffer,
+ FrameBuffer* output_buffer) override;
+};
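+
+// Example usage, as a minimal sketch: the CreateFromRgbaRawBuffer factory is
+// assumed (e.g. from frame_buffer_common_utils.h), and the dimensions and the
+// rgba_data/rotated_data pointers are illustrative:
+//
+//   LibyuvFrameBufferUtils utils;
+//   std::unique_ptr<FrameBuffer> input = CreateFromRgbaRawBuffer(
+//       rgba_data, FrameBuffer::Dimension{/*width=*/640, /*height=*/480});
+//   // The output buffer must be sized for the rotated dimensions.
+//   std::unique_ptr<FrameBuffer> output = CreateFromRgbaRawBuffer(
+//       rotated_data, FrameBuffer::Dimension{/*width=*/480, /*height=*/640});
+//   // Rotation is counter-clockwise; the angle must be a multiple of 90.
+//   absl::Status status =
+//       utils.Rotate(*input, /*angle_deg=*/90, output.get());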
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_LIBYUV_FRAME_BUFFER_UTILS_H_
diff --git a/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc b/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
new file mode 100644
index 00000000..773ab76f
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc
@@ -0,0 +1,225 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"
+
+#include <cmath>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+namespace {
+
+using ::absl::StatusCode;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+
+// Used to prevent log(<=0.0) in ClampedLog() calls.
+constexpr float kLogScoreMinimum = 1e-16;
+
+// Returns the following, depending on x:
+//   x >= threshold: log(x)
+//   x < threshold:  2 * log(threshold) - log(2 * threshold - x)
+// This form (a) is anti-symmetric about the threshold and (b) has continuous
+// value and first derivative. This is done to prevent taking the log of values
+// close to 0 which can lead to floating point errors and is better than simple
+// clamping since it preserves order for scores less than the threshold.
+float ClampedLog(float x, float threshold) {
+ if (x < threshold) {
+    return 2.0 * std::log(static_cast<double>(threshold)) -
+           std::log(2.0 * threshold - x);
+ }
+ return std::log(static_cast<double>(x));
+}
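+
+// Sanity check of the claims above: at x == threshold the lower branch gives
+// 2 * log(t) - log(2t - t) = log(t), matching the upper branch, and its
+// derivative 1 / (2t - x) evaluates to 1 / t at x == t, matching d/dx log(x).
+// Order is preserved below the threshold since 1 / (2t - x) > 0 for x < t.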
+
+// Applies the specified score transformation to the provided score.
+// Currently supports the following:
+//   IDENTITY         : f(x) = x
+//   LOG              : f(x) = log(x)
+//   INVERSE_LOGISTIC : f(x) = log(x) - log(1-x)
+float ApplyScoreTransformation(float score, const ScoreTransformation& type) {
+ switch (type) {
+ case ScoreTransformation::kIDENTITY:
+ return score;
+ case ScoreTransformation::kINVERSE_LOGISTIC:
+ return (ClampedLog(score, kLogScoreMinimum) -
+ ClampedLog(1.0 - score, kLogScoreMinimum));
+ case ScoreTransformation::kLOG:
+ return ClampedLog(score, kLogScoreMinimum);
+ }
+}
+
+// Builds a single Sigmoid from the label name and associated CSV file line.
+StatusOr<Sigmoid> SigmoidFromLabelAndLine(absl::string_view label,
+ absl::string_view line) {
+ std::vector<absl::string_view> str_params = absl::StrSplit(line, ',');
+ if (str_params.size() != 3 && str_params.size() != 4) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Expected 3 or 4 parameters per line in score "
+ "calibration file, got %d.",
+ str_params.size()),
+ TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError);
+ }
+ std::vector<float> float_params(4);
+ for (int i = 0; i < str_params.size(); ++i) {
+ if (!absl::SimpleAtof(str_params[i], &float_params[i])) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Could not parse score calibration parameter as float: %s.",
+ str_params[i]),
+ TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError);
+ }
+ }
+ Sigmoid sigmoid;
+ sigmoid.label = std::string(label);
+ sigmoid.scale = float_params[0];
+ sigmoid.slope = float_params[1];
+ sigmoid.offset = float_params[2];
+ if (str_params.size() == 4) {
+ sigmoid.min_uncalibrated_score = float_params[3];
+ }
+ return sigmoid;
+}
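+
+// For instance, for label "cat" the line "1.0,0.2,0.1,0.3" yields scale=1.0,
+// slope=0.2, offset=0.1 and min_uncalibrated_score=0.3; note that scale comes
+// first in the file format.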
+
+// Converts a tflite::ScoreTransformationType to its
+// tflite::task::vision::ScoreTransformation equivalent.
+ScoreTransformation ConvertScoreTransformationType(
+ tflite::ScoreTransformationType type) {
+ switch (type) {
+ case tflite::ScoreTransformationType_IDENTITY:
+ return ScoreTransformation::kIDENTITY;
+ case tflite::ScoreTransformationType_LOG:
+ return ScoreTransformation::kLOG;
+ case tflite::ScoreTransformationType_INVERSE_LOGISTIC:
+ return ScoreTransformation::kINVERSE_LOGISTIC;
+ }
+}
+
+} // namespace
+
+std::ostream& operator<<(std::ostream& os, const Sigmoid& s) {
+ os << s.label << "," << s.slope << "," << s.offset << "," << s.scale;
+ if (s.min_uncalibrated_score.has_value()) {
+ os << "," << s.min_uncalibrated_score.value();
+ }
+ return os;
+}
+
+ScoreCalibration::ScoreCalibration() {}
+ScoreCalibration::~ScoreCalibration() {}
+
+absl::Status ScoreCalibration::InitializeFromParameters(
+ const SigmoidCalibrationParameters& params) {
+  sigmoid_parameters_ = params;
+ // Fill in the map from label -> sigmoid.
+ sigmoid_parameters_map_.clear();
+ for (const auto& sigmoid : sigmoid_parameters_.sigmoid) {
+ sigmoid_parameters_map_.insert_or_assign(sigmoid.label, sigmoid);
+ }
+ return absl::OkStatus();
+}
+
+float ScoreCalibration::ComputeCalibratedScore(const std::string& label,
+ float uncalibrated_score) const {
+ absl::optional<Sigmoid> sigmoid = FindSigmoidParameters(label);
+ if (!sigmoid.has_value() ||
+ (sigmoid.value().min_uncalibrated_score.has_value() &&
+ uncalibrated_score < sigmoid.value().min_uncalibrated_score.value())) {
+ return sigmoid_parameters_.default_score;
+ }
+
+ float transformed_score = ApplyScoreTransformation(
+ uncalibrated_score, sigmoid_parameters_.score_transformation);
+ float scale_shifted_score =
+ transformed_score * sigmoid.value().slope + sigmoid.value().offset;
+
+ // For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0
+  // and exp(x) / (1+exp(x)) when scale_shifted_score < 0, so that std::exp
+  // never receives a large positive argument that could overflow.
+ if (scale_shifted_score >= 0.0) {
+ return sigmoid.value().scale /
+ (1.0 + std::exp(static_cast<double>(-scale_shifted_score)));
+ } else {
+ float score_exp = std::exp(static_cast<double>(scale_shifted_score));
+ return sigmoid.value().scale * score_exp / (1.0 + score_exp);
+ }
+}
+
+absl::optional<Sigmoid> ScoreCalibration::FindSigmoidParameters(
+ const std::string& label) const {
+ auto it = sigmoid_parameters_map_.find(label);
+ if (it != sigmoid_parameters_map_.end()) {
+ return it->second;
+ } else if (sigmoid_parameters_.default_sigmoid.has_value()) {
+ return sigmoid_parameters_.default_sigmoid.value();
+ }
+ return absl::nullopt;
+}
+
+StatusOr<SigmoidCalibrationParameters> BuildSigmoidCalibrationParams(
+ const tflite::ScoreCalibrationOptions& score_calibration_options,
+ absl::string_view score_calibration_file,
+ const std::vector<LabelMapItem>& label_map_items) {
+ // Split file lines and perform sanity checks.
+ if (score_calibration_file.empty()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Expected non-empty score calibration file.");
+ }
+ std::vector<absl::string_view> lines =
+ absl::StrSplit(score_calibration_file, '\n');
+ if (label_map_items.size() != lines.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Mismatch between number of labels (%d) and score "
+ "calibration parameters (%d).",
+ label_map_items.size(), lines.size()),
+ TfLiteSupportStatus::kMetadataNumLabelsMismatchError);
+ }
+ // Initialize SigmoidCalibrationParameters with its class-agnostic parameters.
+ SigmoidCalibrationParameters sigmoid_params = {};
+ sigmoid_params.score_transformation = ConvertScoreTransformationType(
+ score_calibration_options.score_transformation());
+ sigmoid_params.default_score = score_calibration_options.default_score();
+ std::vector<Sigmoid> sigmoid_vector;
+ // Fill sigmoids for each class with parameters in the file.
+ for (int i = 0; i < label_map_items.size(); ++i) {
+ if (lines[i].empty()) {
+ continue;
+ }
+ ASSIGN_OR_RETURN(Sigmoid sigmoid, SigmoidFromLabelAndLine(
+ label_map_items[i].name, lines[i]));
+ sigmoid_vector.emplace_back(std::move(sigmoid));
+ }
+ sigmoid_params.sigmoid = std::move(sigmoid_vector);
+
+ return sigmoid_params;
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h b/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
new file mode 100644
index 00000000..c3f0bf8a
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h
@@ -0,0 +1,146 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_
+
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Sigmoid calibration parameters for a single label.
+struct Sigmoid {
+ Sigmoid() : scale(1.0) {}
+ Sigmoid(std::string label, float slope, float offset, float scale = 1.0,
+ absl::optional<float> min_uncalibrated_score = absl::nullopt)
+ : label(label),
+ slope(slope),
+ offset(offset),
+ scale(scale),
+ min_uncalibrated_score(min_uncalibrated_score) {}
+
+ bool operator==(const Sigmoid& other) const {
+ return label == other.label && slope == other.slope &&
+ offset == other.offset && scale == other.scale &&
+ min_uncalibrated_score == other.min_uncalibrated_score;
+ }
+
+ // Unique label corresponding to the sigmoid parameters.
+ std::string label;
+ float slope;
+ float offset;
+ float scale;
+ absl::optional<float> min_uncalibrated_score;
+};
+
+std::ostream& operator<<(std::ostream& os, const Sigmoid& s);
+
+// Transformation function applied to scores prior to sigmoid calibration.
+enum class ScoreTransformation {
+ kIDENTITY, // f(x) = x
+ kLOG, // f(x) = log(x)
+ kINVERSE_LOGISTIC // f(x) = log(x) - log(1 - x)
+};
+
+// Sigmoid calibration parameters.
+struct SigmoidCalibrationParameters {
+ SigmoidCalibrationParameters()
+ : default_score(0.0),
+ score_transformation(ScoreTransformation::kIDENTITY) {}
+ explicit SigmoidCalibrationParameters(
+ std::vector<Sigmoid> sigmoid,
+ ScoreTransformation score_transformation = ScoreTransformation::kIDENTITY,
+ absl::optional<Sigmoid> default_sigmoid = absl::nullopt,
+ float default_score = 0.0)
+ : sigmoid(sigmoid),
+ default_sigmoid(default_sigmoid),
+ default_score(default_score),
+ score_transformation(score_transformation) {}
+  // The vector of Sigmoid parameters associated with the ScoreCalibration
+  // instance.
+ std::vector<Sigmoid> sigmoid;
+ // If set, this sigmoid will be applied to any non-matching labels.
+ absl::optional<Sigmoid> default_sigmoid;
+ // The default score for non-matching labels. Only used if default_sigmoid
+ // isn't set.
+ float default_score;
+  // Function used to transform scores prior to sigmoid fitting.
+ ScoreTransformation score_transformation;
+};
+
+// This class is used to calibrate predicted scores so that scores are
+// comparable across labels. Depending on the particular calibration parameters
+// being used, the calibrated scores can also be approximately interpreted as a
+// likelihood of being correct. For a given TF Lite model, such parameters are
+// typically obtained from TF Lite Metadata (see ScoreCalibrationOptions).
+class ScoreCalibration {
+ public:
+ ScoreCalibration();
+ ~ScoreCalibration();
+
+  // Copies the input parameters and constructs a label-to-sigmoid map.
+ absl::Status InitializeFromParameters(
+ const SigmoidCalibrationParameters& params);
+
+ // Returns a calibrated score given a label string and uncalibrated score. The
+ // calibrated score will be in the range [0.0, 1.0] and can loosely be
+ // interpreted as a likelihood of the label being correct.
+ float ComputeCalibratedScore(const std::string& label,
+ float uncalibrated_score) const;
+
+ private:
+ // Finds the sigmoid parameters corresponding to the provided label.
+ absl::optional<Sigmoid> FindSigmoidParameters(const std::string& label) const;
+
+ // Parameters for internal states.
+ SigmoidCalibrationParameters sigmoid_parameters_;
+
+ // Maps label strings to the particular sigmoid stored in sigmoid_parameters_.
+ absl::flat_hash_map<std::string, Sigmoid> sigmoid_parameters_map_;
+};
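+
+// Example usage, as a minimal sketch; the parameters below are illustrative
+// and built by hand rather than from TF Lite Metadata:
+//
+//   SigmoidCalibrationParameters params(
+//       /*sigmoid=*/{Sigmoid("cat", /*slope=*/1.0f, /*offset=*/0.0f)},
+//       /*score_transformation=*/ScoreTransformation::kLOG);
+//   ScoreCalibration calibration;
+//   if (calibration.InitializeFromParameters(params).ok()) {
+//     // transformed = log(0.5) ~= -0.693, so the sigmoid yields
+//     // 1.0 * exp(-0.693) / (1 + exp(-0.693)) ~= 0.333.
+//     float calibrated = calibration.ComputeCalibratedScore("cat", 0.5f);
+//   }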
+
+// Builds SigmoidCalibrationParameters using data obtained from TF Lite Metadata
+// (see ScoreCalibrationOptions in metadata schema).
+//
+// The provided `score_calibration_file` represents the contents of the score
+// calibration associated file (TENSOR_AXIS_SCORE_CALIBRATION), i.e. one set of
+// parameters (scale, slope, etc.) per line. Each line must be in 1:1
+// correspondence with `label_map_items`, so as to associate each sigmoid to its
+// corresponding label name. Returns an error if no valid parameters could be
+// built (e.g. malformed parameters).
+tflite::support::StatusOr<SigmoidCalibrationParameters>
+BuildSigmoidCalibrationParams(
+ const tflite::ScoreCalibrationOptions& score_calibration_options,
+ absl::string_view score_calibration_file,
+ const std::vector<LabelMapItem>& label_map_items);
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_
diff --git a/tensorflow_lite_support/cc/text/tokenizers/BUILD b/tensorflow_lite_support/cc/text/tokenizers/BUILD
new file mode 100644
index 00000000..3ad8da2f
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/BUILD
@@ -0,0 +1,191 @@
+# This package contains C++ support libraries that Java libraries can invoke.
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+load(
+ "@org_tensorflow//tensorflow/lite:build_def.bzl",
+ "tflite_copts",
+ "tflite_jni_binary",
+)
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "tokenizer",
+ hdrs = [
+ "tokenizer.h",
+ ],
+ deps = [
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "tokenizer_jni_lib",
+ srcs = [
+ "tokenizer_jni_lib.cc",
+ ],
+ hdrs = [
+ "tokenizer_jni_lib.h",
+ ],
+ deps = [
+ ":tokenizer",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "@org_tensorflow//tensorflow/lite/java/jni",
+ ],
+)
+
+cc_library(
+ name = "bert_tokenizer",
+ srcs = [
+ "bert_tokenizer.cc",
+ ],
+ hdrs = [
+ "bert_tokenizer.h",
+ ],
+ deps = [
+ ":tokenizer",
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/utils:common_utils",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_googlesource_code_re2//:re2",
+ "@org_tensorflow_text//tensorflow_text/core/kernels:regex_split",
+ "@org_tensorflow_text//tensorflow_text/core/kernels:wordpiece_tokenizer",
+ ],
+)
+
+cc_library(
+ name = "bert_tokenizer_jni_lib",
+ srcs = [
+ "bert_tokenizer_jni.cc",
+ ],
+ copts = tflite_copts(),
+ linkopts = [
+ "-lm",
+ "-ldl",
+ ],
+ deps = [
+ ":bert_tokenizer",
+ ":tokenizer_jni_lib",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "@com_google_absl//absl/memory",
+ "@org_tensorflow//tensorflow/lite/java/jni",
+ ],
+ alwayslink = 1,
+)
+
+tflite_jni_binary(
+ name = "libbert_tokenizer_jni.so",
+ deps = [
+ ":bert_tokenizer_jni_lib",
+ ],
+)
+
+cc_library(
+ name = "bert_tokenizer_runtime",
+ srcs = ["libbert_tokenizer_jni.so"],
+ alwayslink = 1,
+)
+
+android_library(
+ name = "bert_tokenizer_jni",
+ custom_package = "org.tensorflow.lite.support.text",
+ manifest = "DummyManifest.xml",
+ resource_files = [],
+ deps = [
+ ":bert_tokenizer_runtime", # build_cleaner: skip
+ ],
+)
+
+cc_library(
+ name = "sentencepiece_tokenizer",
+ hdrs = [
+ "sentencepiece_tokenizer.h",
+ ],
+ deps = [
+ ":tokenizer",
+ "@com_google_sentencepiece//src:sentencepiece_processor",
+ ],
+)
+
+cc_library(
+ name = "sentencepiece_jni_lib",
+ srcs = [
+ "sentencepiece_jni.cc",
+ ],
+ copts = tflite_copts(),
+ linkopts = [
+ "-lm",
+ "-ldl",
+ ],
+ deps = [
+ ":sentencepiece_tokenizer",
+ ":tokenizer_jni_lib",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@org_tensorflow//tensorflow/lite/java/jni",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "sentencepiece_runtime",
+ srcs = ["libsentencepiece_jni.so"],
+ alwayslink = 1,
+)
+
+tflite_jni_binary(
+ name = "libsentencepiece_jni.so",
+ deps = [
+ ":sentencepiece_jni_lib",
+ ],
+)
+
+android_library(
+ name = "sentencepiece_jni",
+ custom_package = "org.tensorflow.lite.support.text",
+ manifest = "DummyManifest.xml",
+ resource_files = [],
+ deps = [
+ ":sentencepiece_runtime", # build_cleaner: skip
+ ],
+)
+
+cc_library(
+ name = "tokenizer_utils",
+ srcs = ["tokenizer_utils.cc"],
+ hdrs = [
+ "tokenizer_utils.h",
+ ],
+ deps = [
+ ":bert_tokenizer",
+ ":regex_tokenizer",
+ ":sentencepiece_tokenizer",
+ ":tokenizer",
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "//tensorflow_lite_support/metadata/cc:metadata_extractor",
+ "@com_google_absl//absl/status",
+ ],
+)
+
+cc_library(
+ name = "regex_tokenizer",
+ srcs = [
+ "regex_tokenizer.cc",
+ ],
+ hdrs = [
+ "regex_tokenizer.h",
+ ],
+ deps = [
+ ":tokenizer",
+ "//tensorflow_lite_support/cc/utils:common_utils",
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/strings",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
diff --git a/tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml b/tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml
new file mode 100644
index 00000000..ff025072
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml
@@ -0,0 +1,19 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="org.tensorflow.lite.support.text">
+</manifest>
diff --git a/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
new file mode 100644
index 00000000..aeb887c6
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc
@@ -0,0 +1,108 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
+
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+
+namespace tflite {
+namespace support {
+namespace text {
+namespace tokenizer {
+
+FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece(
+ const std::vector<std::string>& vocab)
+ : vocab_{vocab} {
+ for (int i = 0; i < vocab_.size(); ++i) {
+ index_map_[vocab_[i]] = i;
+ }
+}
+
+tensorflow::text::LookupStatus FlatHashMapBackedWordpiece::Contains(
+ absl::string_view key, bool* value) const {
+ *value = index_map_.contains(key);
+ return tensorflow::text::LookupStatus();
+}
+
+bool FlatHashMapBackedWordpiece::LookupId(const absl::string_view key,
+ int* result) const {
+ auto it = index_map_.find(key);
+ if (it == index_map_.end()) {
+ return false;
+ }
+ *result = it->second;
+ return true;
+}
+
+bool FlatHashMapBackedWordpiece::LookupWord(int vocab_id,
+ absl::string_view* result) const {
+ if (vocab_id >= vocab_.size() || vocab_id < 0) {
+ return false;
+ }
+ *result = vocab_[vocab_id];
+ return true;
+}
+
+TokenizerResult BertTokenizer::Tokenize(const std::string& input) {
+ return TokenizeWordpiece(input);
+}
+
+WordpieceTokenizerResult BertTokenizer::TokenizeWordpiece(
+ const std::string& input) {
+ WordpieceTokenizerResult result;
+ std::vector<std::string>& subwords = result.subwords;
+ std::vector<int>& wp_absolute_begin_offset = result.wp_begin_offset;
+ std::vector<int>& wp_absolute_end_offset = result.wp_end_offset;
+
+ std::vector<absl::string_view> tokens;
+ std::vector<int64> begin_offsets;
+ std::vector<int64> end_offsets;
+
+  // Split the input into tokens using the delimiter regexes.
+ tensorflow::text::RegexSplit(input, delim_re_, true, include_delim_re_,
+ &tokens, &begin_offsets, &end_offsets);
+
+ for (int token_index = 0; token_index < tokens.size(); token_index++) {
+ auto& token = tokens[token_index];
+ int num_word_pieces = 0;
+ tensorflow::text::LookupStatus status = WordpieceTokenize(
+ token, options_.max_bytes_per_token, options_.max_chars_per_subtoken,
+ options_.suffix_indicator, options_.use_unknown_token,
+ options_.unknown_token, options_.split_unknown_chars, &vocab_,
+ &subwords, &wp_absolute_begin_offset, &wp_absolute_end_offset,
+ &num_word_pieces);
+
+ result.row_lengths.emplace_back(num_word_pieces);
+    // For the last num_word_pieces entries appended to
+    // wp_absolute_begin_offset and wp_absolute_end_offset, shift them by
+    // begin_offsets[token_index] so they become absolute offsets into the
+    // original input.
+ int absolute_offset_size = wp_absolute_begin_offset.size();
+ for (int i = num_word_pieces; i > 0; i--) {
+ wp_absolute_begin_offset[absolute_offset_size - i] +=
+ begin_offsets[token_index];
+ wp_absolute_end_offset[absolute_offset_size - i] +=
+ begin_offsets[token_index];
+ }
+ if (!status.success) {
+ return result;
+ }
+ }
+
+ return result;
+}
+
+} // namespace tokenizer
+} // namespace text
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
new file mode 100644
index 00000000..14a006c2
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h
@@ -0,0 +1,149 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "re2/re2.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+#include "tensorflow_lite_support/cc/utils/common_utils.h"
+#include "tensorflow_text/core/kernels/regex_split.h"
+#include "tensorflow_text/core/kernels/wordpiece_tokenizer.h"
+
+namespace tflite {
+namespace support {
+namespace text {
+namespace tokenizer {
+
+constexpr char kDefaultDelimRe[] =
+ R"((\s+|[!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))";
+constexpr char kDefaultIncludeDelimRe[] =
+ R"(([!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))";
+constexpr int kDefaultMaxBytesPerToken = 100;
+constexpr int kDefaultMaxCharsPerSubToken = 100;
+constexpr char kDefaultSuffixIndicator[] = "##";
+constexpr bool kDefaultUseUnknownToken = true;
+constexpr char kDefaultUnknownToken[] = "[UNK]";
+constexpr bool kDefaultSplitUnknownChars = false;
+
+// Result of wordpiece tokenization including subwords and offsets.
+// Example:
+// input: tokenize me please
+// subwords: token ##ize me plea ##se
+// wp_begin_offset: [0, 5, 9, 12, 16]
+// wp_end_offset: [ 5, 8, 11, 16, 18]
+// row_lengths: [2, 1, 1]
+struct WordpieceTokenizerResult : TokenizerResult {
+ std::vector<int> wp_begin_offset;
+ std::vector<int> wp_end_offset;
+ std::vector<int> row_lengths;
+};
+
+// Options to create a BertTokenizer.
+struct BertTokenizerOptions {
+ int max_bytes_per_token = kDefaultMaxBytesPerToken;
+ int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken;
+ std::string suffix_indicator = kDefaultSuffixIndicator;
+ bool use_unknown_token = kDefaultUseUnknownToken;
+ std::string unknown_token = kDefaultUnknownToken;
+ bool split_unknown_chars = kDefaultSplitUnknownChars;
+ std::string delim_str = kDefaultDelimRe;
+ std::string include_delim_str = kDefaultIncludeDelimRe;
+};
+
+// A flat-hash-map-based implementation of WordpieceVocab, used by
+// BertTokenizer when invoking tensorflow::text::WordpieceTokenize.
+class FlatHashMapBackedWordpiece : public tensorflow::text::WordpieceVocab {
+ public:
+ explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab);
+
+ tensorflow::text::LookupStatus Contains(absl::string_view key,
+ bool* value) const override;
+ bool LookupId(absl::string_view key, int* result) const;
+ bool LookupWord(int vocab_id, absl::string_view* result) const;
+ int VocabularySize() const { return vocab_.size(); }
+
+ private:
+  // All words, indexed by their position in the vocabulary file.
+ std::vector<std::string> vocab_;
+ absl::flat_hash_map<absl::string_view, int> index_map_;
+};
+
+// Wordpiece tokenizer for BERT models. Initialized with a vocab file or
+// vector.
+class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer {
+ public:
+ // Initialize the tokenizer from vocab vector and tokenizer configs.
+ explicit BertTokenizer(const std::vector<std::string>& vocab,
+ const BertTokenizerOptions& options = {})
+ : vocab_{FlatHashMapBackedWordpiece(vocab)},
+ options_{options},
+ delim_re_{options.delim_str},
+ include_delim_re_{options.include_delim_str} {}
+
+ // Initialize the tokenizer from file path to vocab and tokenizer configs.
+ explicit BertTokenizer(const std::string& path_to_vocab,
+ const BertTokenizerOptions& options = {})
+ : BertTokenizer(utils::LoadVocabFromFile(path_to_vocab), options) {}
+
+ // Initialize the tokenizer from buffer and size of vocab and tokenizer
+ // configs.
+ BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size,
+ const BertTokenizerOptions& options = {})
+ : BertTokenizer(
+ utils::LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size),
+ options) {}
+
+ // Perform tokenization, return tokenized results containing the subwords.
+ TokenizerResult Tokenize(const std::string& input) override;
+
+  // Perform tokenization, returning the wordpiece-specific result, including
+  // subwords and offsets.
+ WordpieceTokenizerResult TokenizeWordpiece(const std::string& input);
+
+ // Check if a certain key is included in the vocab.
+ tensorflow::text::LookupStatus Contains(const absl::string_view key,
+ bool* value) const {
+ return vocab_.Contains(key, value);
+ }
+
+ // Find the id of a wordpiece.
+ bool LookupId(absl::string_view key, int* result) const override {
+ return vocab_.LookupId(key, result);
+ }
+
+ // Find the wordpiece from an id.
+ bool LookupWord(int vocab_id, absl::string_view* result) const override {
+ return vocab_.LookupWord(vocab_id, result);
+ }
+
+ int VocabularySize() const { return vocab_.VocabularySize(); }
+
+ private:
+ tflite::support::text::tokenizer::FlatHashMapBackedWordpiece vocab_;
+ BertTokenizerOptions options_;
+ RE2 delim_re_;
+ RE2 include_delim_re_;
+};
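+
+// Example usage, as a minimal sketch; the vocab below is illustrative, and
+// real vocabs include the "##"-prefixed subwords:
+//
+//   std::vector<std::string> vocab = {"[UNK]", "token", "##ize", "me"};
+//   BertTokenizer tokenizer(vocab);
+//   WordpieceTokenizerResult result =
+//       tokenizer.TokenizeWordpiece("tokenize me");
+//   // Expected: result.subwords == {"token", "##ize", "me"}, with
+//   // wp_begin_offset/wp_end_offset holding byte offsets into the input.
+//   int id;
+//   if (tokenizer.LookupId("##ize", &id)) {
+//     // id == 2, the index of "##ize" in the vocab vector.
+//   }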
+
+} // namespace tokenizer
+} // namespace text
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_
diff --git a/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
new file mode 100644
index 00000000..442d06ec
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc
@@ -0,0 +1,87 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+
+namespace tflite {
+namespace support {
+
+using ::tflite::support::text::tokenizer::BertTokenizer;
+using ::tflite::support::text::tokenizer::BertTokenizerOptions;
+using ::tflite::support::utils::StringListToVector;
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResource( // NOLINT
+ JNIEnv* env, jobject thiz, jobject vocab_list, jint max_bytes_per_token,
+ jint max_chars_per_sub_token, jstring jsuffix_indicator,
+ jboolean use_unknown_token, jstring junknown_token,
+ jboolean split_unknown_chars) {
+ // Convert java.util.List<String> into std::vector<string>
+ std::vector<std::string> vocab = StringListToVector(env, vocab_list);
+
+ // Convert jstrings to std::string
+ const char* raw_suffix_indicator =
+ env->GetStringUTFChars(jsuffix_indicator, JNI_FALSE);
+ std::string suffix_indicator(raw_suffix_indicator);
+
+ const char* raw_unknown_token =
+ env->GetStringUTFChars(junknown_token, JNI_FALSE);
+ std::string unknown_token(raw_unknown_token);
+
+ auto handle = absl::make_unique<BertTokenizer>(
+ vocab, BertTokenizerOptions{
+ .max_bytes_per_token = max_bytes_per_token,
+ .max_chars_per_subtoken = max_chars_per_sub_token,
+ .suffix_indicator = suffix_indicator,
+ .use_unknown_token = static_cast<bool>(use_unknown_token),
+ .unknown_token = unknown_token,
+ .split_unknown_chars = static_cast<bool>(split_unknown_chars),
+ .delim_str = text::tokenizer::kDefaultDelimRe,
+ .include_delim_str = text::tokenizer::kDefaultIncludeDelimRe});
+
+ env->ReleaseStringUTFChars(jsuffix_indicator, raw_suffix_indicator);
+ env->ReleaseStringUTFChars(junknown_token, raw_unknown_token);
+
+ return reinterpret_cast<jlong>(handle.release());
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeUnloadResource( // NOLINT
+ JNIEnv* env, jobject thiz, jlong handle) {
+ delete reinterpret_cast<BertTokenizer*>(handle);
+ return 0;
+}
+
+extern "C" JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeTokenize(
+ JNIEnv* env, jobject thiz, jlong handle, jstring jtext) {
+ return nativeTokenize(env, handle, jtext);
+}
+
+extern "C" JNIEXPORT jintArray JNICALL
+Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeConvertTokensToIds( // NOLINT
+ JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) {
+ return nativeConvertTokensToIds(env, handle, jtokens);
+}
+
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
new file mode 100644
index 00000000..38aff880
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc
@@ -0,0 +1,125 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
+
+#include <iostream>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/substitute.h"
+#include "tensorflow_lite_support/cc/utils/common_utils.h"
+namespace tflite {
+namespace support {
+namespace text {
+namespace tokenizer {
+
+namespace {
+constexpr char kStart[] = "<START>";
+constexpr char kPad[] = "<PAD>";
+constexpr char kUnknown[] = "<UNKNOWN>";
+
+void BuildIndexTokenMap(
+ const absl::node_hash_map<std::string, int>& token_index_map,
+ absl::node_hash_map<int, absl::string_view>* index_token_map) {
+ for (const auto& token : token_index_map) {
+ (*index_token_map)[token.second] = token.first;
+ }
+}
+
+} // namespace
+
+// RE2::FindAndConsume requires delim_re_ to have a capturing group in order
+// to capture the matched delimiter. Surround the regex with parentheses to
+// create such a group; it's fine if the regex is already parenthesized.
+RegexTokenizer::RegexTokenizer(const std::string& regex_pattern,
+ const std::string& path_to_vocab)
+ : delim_re_{absl::Substitute("($0)", regex_pattern)},
+ token_index_map_{utils::LoadVocabAndIndexFromFile(path_to_vocab)} {
+  BuildIndexTokenMap(token_index_map_, &index_token_map_);
+}
+
+RegexTokenizer::RegexTokenizer(const std::string& regex_pattern,
+ const char* vocab_buffer_data,
+ size_t vocab_buffer_size)
+ : delim_re_{absl::Substitute("($0)", regex_pattern)},
+ token_index_map_{utils::LoadVocabAndIndexFromBuffer(vocab_buffer_data,
+ vocab_buffer_size)} {
+  BuildIndexTokenMap(token_index_map_, &index_token_map_);
+}
+
+TokenizerResult RegexTokenizer::Tokenize(const std::string& input) {
+  // Construct from the string directly so embedded NULs do not truncate it.
+  absl::string_view leftover(input);
+ absl::string_view last_end = leftover;
+
+ TokenizerResult result;
+
+ // Keep looking for split points until we have reached the end of the input.
+ absl::string_view extracted_delim_token;
+ while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) {
+ absl::string_view token(last_end.data(),
+ extracted_delim_token.data() - last_end.data());
+ bool has_non_empty_token = token.length() > 0;
+
+ last_end = leftover;
+
+    // Record the previous token, but only if it is non-empty.
+ if (has_non_empty_token) {
+ result.subwords.push_back(std::string(token));
+ }
+ }
+
+ // Close the last token.
+ if (!leftover.empty()) {
+ result.subwords.push_back(std::string(leftover));
+ }
+
+ return result;
+}
+
+bool RegexTokenizer::LookupId(absl::string_view key, int* result) const {
+ auto it = token_index_map_.find(key);
+ if (it == token_index_map_.end()) {
+ return false;
+ }
+ *result = it->second;
+ return true;
+}
+
+bool RegexTokenizer::LookupWord(int vocab_id, absl::string_view* result) const {
+ auto it = index_token_map_.find(vocab_id);
+ if (it == index_token_map_.end()) {
+ return false;
+ }
+ *result = it->second;
+ return true;
+}
+
+bool RegexTokenizer::GetStartToken(int* start_token) {
+ return LookupId(kStart, start_token);
+}
+
+bool RegexTokenizer::GetPadToken(int* pad_token) {
+ return LookupId(kPad, pad_token);
+}
+
+bool RegexTokenizer::GetUnknownToken(int* unknown_token) {
+ return LookupId(kUnknown, unknown_token);
+}
+
+} // namespace tokenizer
+} // namespace text
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h
new file mode 100644
index 00000000..c53ae496
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h
@@ -0,0 +1,59 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_
+
+#include "absl/container/node_hash_map.h"
+#include "re2/re2.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+
+namespace tflite {
+namespace support {
+namespace text {
+namespace tokenizer {
+
+// Tokenizer that loads a vocabulary and splits text with a regular
+// expression.
+class RegexTokenizer : public Tokenizer {
+ public:
+ explicit RegexTokenizer(const std::string& regex_pattern,
+ const std::string& path_to_vocab);
+
+ explicit RegexTokenizer(const std::string& regex_pattern,
+ const char* vocab_buffer_data,
+ size_t vocab_buffer_size);
+
+ TokenizerResult Tokenize(const std::string& input) override;
+
+ bool LookupId(absl::string_view key, int* result) const override;
+
+ bool LookupWord(int vocab_id, absl::string_view* result) const override;
+
+ bool GetStartToken(int* start_token);
+ bool GetPadToken(int* pad_token);
+ bool GetUnknownToken(int* unknown_token);
+
+ private:
+ RE2 delim_re_;
+ absl::node_hash_map<std::string, int> token_index_map_;
+ absl::node_hash_map<int, absl::string_view> index_token_map_;
+};
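+
+// Example usage, as a minimal sketch; the vocab path is a placeholder and the
+// pattern simply splits on whitespace:
+//
+//   RegexTokenizer tokenizer(R"(\s+)", "/path/to/vocab.txt");
+//   TokenizerResult result = tokenizer.Tokenize("hello world");
+//   // Expected: result.subwords == {"hello", "world"}; the matched
+//   // delimiters themselves are dropped.
+//   int start_id;
+//   if (tokenizer.GetStartToken(&start_id)) {
+//     // start_id is the vocab id of "<START>", if present in the vocab.
+//   }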
+
+} // namespace tokenizer
+} // namespace text
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_
diff --git a/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
new file mode 100644
index 00000000..88065e20
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc
@@ -0,0 +1,64 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+
+#include <cstring>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+
+namespace tflite {
+namespace support {
+
+using ::tflite::support::text::tokenizer::SentencePieceTokenizer;
+using ::tflite::support::utils::GetMappedFileBuffer;
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLoadResource( // NOLINT
+ JNIEnv* env, jobject obj, jobject model_buffer) {
+ auto model = GetMappedFileBuffer(env, model_buffer);
+ auto handle =
+ absl::make_unique<SentencePieceTokenizer>(model.data(), model.size());
+ return reinterpret_cast<jlong>(handle.release());
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeUnloadResource( // NOLINT
+ JNIEnv* env, jobject obj, jlong handle) {
+ delete reinterpret_cast<SentencePieceTokenizer*>(handle);
+ return 0;
+}
+
+extern "C" JNIEXPORT jobjectArray JNICALL
+Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeTokenize( // NOLINT
+ JNIEnv* env, jobject thiz, jlong handle, jstring jtext) {
+ return nativeTokenize(env, handle, jtext);
+}
+
+extern "C" JNIEXPORT jintArray JNICALL
+Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeConvertTokensToIds( // NOLINT
+ JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) {
+ return nativeConvertTokensToIds(env, handle, jtokens);
+}
+
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h
new file mode 100644
index 00000000..ed5d3da7
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h
@@ -0,0 +1,74 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "src/sentencepiece_processor.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+
+namespace tflite {
+namespace support {
+namespace text {
+namespace tokenizer {
+
+// SentencePiece tokenizer. Initialized with a model file.
+class SentencePieceTokenizer : public Tokenizer {
+ public:
+ // Initialize the SentencePiece tokenizer from model file path.
+ explicit SentencePieceTokenizer(const std::string& path_to_model) {
+ CHECK_OK(sp_.Load(path_to_model));
+ }
+
+ explicit SentencePieceTokenizer(const char* spmodel_buffer_data,
+ size_t spmodel_buffer_size) {
+ absl::string_view buffer_binary(spmodel_buffer_data, spmodel_buffer_size);
+ CHECK_OK(sp_.LoadFromSerializedProto(buffer_binary));
+ }
+
+ // Perform tokenization, return tokenized results.
+ TokenizerResult Tokenize(const std::string& input) override {
+ TokenizerResult result;
+ std::vector<std::string>& subwords = result.subwords;
+ CHECK_OK(sp_.Encode(input, &subwords));
+ return result;
+ }
+
+ // Find the id of a string token.
+ bool LookupId(absl::string_view key, int* result) const override {
+ *result = sp_.PieceToId(key);
+ return true;
+ }
+
+ // Find the string token of an id.
+ bool LookupWord(int vocab_id, absl::string_view* result) const override {
+ *result = sp_.IdToPiece(vocab_id);
+ return true;
+ }
+
+ private:
+ sentencepiece::SentencePieceProcessor sp_;
+};
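+
+// Example usage, as a short sketch; the model path is a placeholder for a
+// valid SentencePiece model file:
+//
+//   SentencePieceTokenizer tokenizer("/path/to/sp.model");
+//   TokenizerResult result = tokenizer.Tokenize("hello world");
+//   // Each entry of result.subwords is a SentencePiece piece, e.g. "▁hello".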
+
+} // namespace tokenizer
+} // namespace text
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_
diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h
new file mode 100644
index 00000000..c7545064
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_H_
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+
+namespace tflite {
+namespace support {
+namespace text {
+namespace tokenizer {
+
+struct TokenizerResult {
+ std::vector<std::string> subwords;
+};
+
+// Interface for a general tokenizer.
+class Tokenizer {
+ public:
+ // Perform tokenization to get tokenized results.
+ virtual TokenizerResult Tokenize(const std::string& input) = 0;
+
+ // Find the id of a string token.
+ virtual bool LookupId(absl::string_view key, int* result) const = 0;
+
+ // Find the string token from an id.
+ virtual bool LookupWord(int vocab_id, absl::string_view* result) const = 0;
+
+ // Destructor.
+ virtual ~Tokenizer() = default;
+};
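+
+// A toy implementation sketch of the interface above (not part of the
+// library), splitting on single spaces with no vocab attached; it assumes
+// absl/strings/str_split.h is included:
+//
+//   class WhitespaceTokenizer : public Tokenizer {
+//    public:
+//     TokenizerResult Tokenize(const std::string& input) override {
+//       TokenizerResult result;
+//       // absl::StrSplit with SkipEmpty drops consecutive delimiters.
+//       result.subwords = absl::StrSplit(input, ' ', absl::SkipEmpty());
+//       return result;
+//     }
+//     // No vocab is attached, so lookups always fail.
+//     bool LookupId(absl::string_view key, int* result) const override {
+//       return false;
+//     }
+//     bool LookupWord(int vocab_id,
+//                     absl::string_view* result) const override {
+//       return false;
+//     }
+//   };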
+
+} // namespace tokenizer
+} // namespace text
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_H_
diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
new file mode 100644
index 00000000..a72523be
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc
@@ -0,0 +1,86 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h"
+
+namespace tflite {
+namespace support {
+
+using ::tflite::support::text::tokenizer::Tokenizer;
+using ::tflite::support::text::tokenizer::TokenizerResult;
+using ::tflite::support::utils::CheckNotNull;
+using ::tflite::support::utils::JStringToString;
+using ::tflite::support::utils::kIllegalStateException;
+
+jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext) {
+ if (handle == 0) {
+ env->ThrowNew(env->FindClass(kIllegalStateException),
+ "Vocab not initialized!");
+ return nullptr;
+ }
+
+ Tokenizer* tokenizer = reinterpret_cast<Tokenizer*>(handle);
+
+ // Get the tokenization results.
+ const TokenizerResult tokenize_result =
+ tokenizer->Tokenize(JStringToString(env, jtext));
+ std::vector<std::string> subwords = tokenize_result.subwords;
+
+ jclass string_class = CheckNotNull(env, env->FindClass("java/lang/String"));
+ jobjectArray result = CheckNotNull(
+ env, env->NewObjectArray(subwords.size(), string_class, nullptr));
+
+ for (int i = 0; i < subwords.size(); ++i) {
+ jstring text = CheckNotNull(env, env->NewStringUTF(subwords[i].data()));
+ if (env->ExceptionCheck()) {
+ return nullptr;
+ }
+
+ env->SetObjectArrayElement(result, i, text);
+ }
+
+ return result;
+}
+
+jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle,
+ jobjectArray jtokens) {
+ if (handle == 0) {
+    env->ThrowNew(env->FindClass(kIllegalStateException),
+                  "Vocab not initialized!");
+ return nullptr;
+ }
+
+ Tokenizer* tokenizer = reinterpret_cast<Tokenizer*>(handle);
+
+ // Get the token ids.
+ const int count = env->GetArrayLength(jtokens);
+ jintArray result = env->NewIntArray(count);
+ jint* jid_ptr = env->GetIntArrayElements(result, nullptr);
+
+ for (int i = 0; i < count; i++) {
+ auto jstr =
+ reinterpret_cast<jstring>(env->GetObjectArrayElement(jtokens, i));
+ const char* token = env->GetStringUTFChars(jstr, JNI_FALSE);
+    // Initialize so a failed lookup yields a defined sentinel instead of an
+    // uninitialized read.
+    int id = -1;
+    tokenizer->LookupId(token, &id);
+ jid_ptr[i] = id;
+ env->ReleaseStringUTFChars(jstr, token);
+ }
+ env->ReleaseIntArrayElements(result, jid_ptr, 0);
+ return result;
+}
+
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
new file mode 100644
index 00000000..fc7285c6
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h
@@ -0,0 +1,36 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_JNI_LIB_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_JNI_LIB_H_
+
+#include <jni.h>
+
+#include <string>
+
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+
+namespace tflite {
+namespace support {
+
+jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext);
+
+jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle,
+ jobjectArray jtokens);
+
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_JNI_LIB_H_
diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
new file mode 100644
index 00000000..3e81c478
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc
@@ -0,0 +1,136 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
+
+#include "absl/status/status.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace support {
+namespace text {
+namespace tokenizer {
+
+using ::tflite::ProcessUnit;
+using ::tflite::SentencePieceTokenizerOptions;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+
+namespace {
+
+StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
+ associated_files,
+ const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
+ if (associated_files == nullptr || associated_files->size() < 1 ||
+ associated_files->Get(0)->name() == nullptr) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "Invalid vocab_file from input process unit.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+ ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
+ metadata_extractor->GetAssociatedFile(
+ associated_files->Get(0)->name()->str()));
+ return vocab_buffer;
+}
+} // namespace
+
+StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit(
+ const tflite::ProcessUnit* tokenizer_process_unit,
+ const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
+ if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "No metadata or input process unit found.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+ switch (tokenizer_process_unit->options_type()) {
+ case ProcessUnitOptions_BertTokenizerOptions: {
+ const tflite::BertTokenizerOptions* options =
+ tokenizer_process_unit->options_as<tflite::BertTokenizerOptions>();
+ ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
+ CheckAndLoadFirstAssociatedFile(options->vocab_file(),
+ metadata_extractor));
+ return absl::make_unique<BertTokenizer>(vocab_buffer.data(),
+ vocab_buffer.size());
+ }
+ case ProcessUnitOptions_SentencePieceTokenizerOptions: {
+ const tflite::SentencePieceTokenizerOptions* options =
+ tokenizer_process_unit->options_as<SentencePieceTokenizerOptions>();
+ ASSIGN_OR_RETURN(absl::string_view model_buffer,
+ CheckAndLoadFirstAssociatedFile(
+ options->sentencePiece_model(), metadata_extractor));
+ // TODO(b/160647204): Extract sentence piece model vocabulary
+ return absl::make_unique<SentencePieceTokenizer>(model_buffer.data(),
+ model_buffer.size());
+ }
+ case ProcessUnitOptions_RegexTokenizerOptions: {
+ const tflite::RegexTokenizerOptions* options =
+ tokenizer_process_unit->options_as<RegexTokenizerOptions>();
+ ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
+ CheckAndLoadFirstAssociatedFile(options->vocab_file(),
+ metadata_extractor));
+ if (options->delim_regex_pattern() == nullptr) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "Invalid delim_regex_pattern from input process unit.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+
+ std::unique_ptr<RegexTokenizer> regex_tokenizer =
+ absl::make_unique<RegexTokenizer>(
+ options->delim_regex_pattern()->str(), vocab_buffer.data(),
+ vocab_buffer.size());
+
+ int unknown_token_id = 0;
+ if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "RegexTokenizer doesn't have <UNKNOWN> token.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+
+ int pad_token_id = 0;
+ if (!regex_tokenizer->GetPadToken(&pad_token_id)) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInvalidArgument,
+ "RegexTokenizer doesn't have <PAD> token.",
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+
+ return regex_tokenizer;
+ }
+ default:
+ return CreateStatusWithPayload(
+ absl::StatusCode::kNotFound,
+ absl::StrCat("Incorrect options_type:",
+ tokenizer_process_unit->options_type()),
+ TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+ }
+}
+
+} // namespace tokenizer
+} // namespace text
+} // namespace support
+} // namespace tflite
+
diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
new file mode 100644
index 00000000..2e50a799
--- /dev/null
+++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h
@@ -0,0 +1,41 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_UTILS_H_
+#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_UTILS_H_
+
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace support {
+namespace text {
+namespace tokenizer {
+
+// Creates a Tokenizer from the model metadata by extracting the tokenizer
+// process unit and loading its associated files.
+tflite::support::StatusOr<std::unique_ptr<Tokenizer>>
+CreateTokenizerFromProcessUnit(
+ const tflite::ProcessUnit* tokenizer_process_unit,
+ const tflite::metadata::ModelMetadataExtractor* metadata_extractor);
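+//
+// Illustrative use (names are hypothetical; assumes a populated extractor):
+//   ASSIGN_OR_RETURN(
+//       std::unique_ptr<Tokenizer> tokenizer,
+//       CreateTokenizerFromProcessUnit(process_unit, metadata_extractor));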
+
+} // namespace tokenizer
+} // namespace text
+} // namespace support
+} // namespace tflite
+
+#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_UTILS_H_
diff --git a/tensorflow_lite_support/cc/utils/BUILD b/tensorflow_lite_support/cc/utils/BUILD
new file mode 100644
index 00000000..07c832f2
--- /dev/null
+++ b/tensorflow_lite_support/cc/utils/BUILD
@@ -0,0 +1,32 @@
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "jni_utils",
+ srcs = [
+ "jni_utils.cc",
+ ],
+ hdrs = [
+ "jni_utils.h",
+ ],
+ deps = [
+ "@com_google_absl//absl/strings",
+ "@org_tensorflow//tensorflow/lite/java/jni",
+ ],
+)
+
+cc_library(
+ name = "common_utils",
+ srcs = [
+ "common_utils.cc",
+ ],
+ hdrs = [
+ "common_utils.h",
+ ],
+ deps = [
+ "@com_google_absl//absl/container:node_hash_map",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow_lite_support/cc/utils/common_utils.cc b/tensorflow_lite_support/cc/utils/common_utils.cc
new file mode 100644
index 00000000..61996f47
--- /dev/null
+++ b/tensorflow_lite_support/cc/utils/common_utils.cc
@@ -0,0 +1,96 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/utils/common_utils.h"
+
+#include <fstream>
+
+#include "absl/strings/str_split.h"
+
+namespace tflite {
+namespace support {
+namespace utils {
+namespace {
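+// Stream buffer over an in-memory char range; exposes a raw buffer through
+// the std::istream interface without copying it.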
+struct membuf : std::streambuf {
+ membuf(char* begin, char* end) { this->setg(begin, begin, end); }
+};
+
+void ReadIStreamLineByLine(
+ std::istream* istream,
+ const std::function<void(std::string)>& line_processor) {
+ std::string str;
+ while (std::getline(*istream, str)) {
+ if (!str.empty()) {
+ line_processor(str);
+ }
+ }
+}
+
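+// Parses lines of the form "<vocab> <index>" (e.g. "hello 7" maps "hello" to
+// index 7) into a map from vocabulary tokens to their indices.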
+absl::node_hash_map<std::string, int> ReadIStreamLineSplits(
+ std::istream* istream) {
+ absl::node_hash_map<std::string, int> vocab_index_map;
+ ReadIStreamLineByLine(istream, [&vocab_index_map](const std::string& str) {
+ std::vector<std::string> v = absl::StrSplit(str, ' ');
+ vocab_index_map[v[0]] = std::stoi(v[1]);
+ });
+ return vocab_index_map;
+}
+
+std::vector<std::string> ReadIStreamByLine(std::istream* istream) {
+ std::vector<std::string> vocab_from_file;
+
+ ReadIStreamLineByLine(istream, [&vocab_from_file](const std::string& str) {
+ vocab_from_file.push_back(str);
+ });
+ return vocab_from_file;
+}
+
+} // namespace
+
+std::vector<std::string> LoadVocabFromFile(const std::string& path_to_vocab) {
+ std::ifstream in(path_to_vocab.c_str());
+ return ReadIStreamByLine(&in);
+}
+
+std::vector<std::string> LoadVocabFromBuffer(const char* vocab_buffer_data,
+ const size_t vocab_buffer_size) {
+ membuf sbuf(const_cast<char*>(vocab_buffer_data),
+ const_cast<char*>(vocab_buffer_data + vocab_buffer_size));
+ std::istream in(&sbuf);
+ return ReadIStreamByLine(&in);
+}
+
+absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile(
+ const std::string& path_to_vocab) {
+ std::ifstream in(path_to_vocab.c_str());
+ return ReadIStreamLineSplits(&in);
+}
+
+absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer(
+ const char* vocab_buffer_data, const size_t vocab_buffer_size) {
+ membuf sbuf(const_cast<char*>(vocab_buffer_data),
+ const_cast<char*>(vocab_buffer_data + vocab_buffer_size));
+ std::istream in(&sbuf);
+ return ReadIStreamLineSplits(&in);
+}
+
+} // namespace utils
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/utils/common_utils.h b/tensorflow_lite_support/cc/utils/common_utils.h
new file mode 100644
index 00000000..36232230
--- /dev/null
+++ b/tensorflow_lite_support/cc/utils/common_utils.h
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_UTILS_COMMON_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_UTILS_COMMON_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/container/node_hash_map.h"
+
+namespace tflite {
+namespace support {
+namespace utils {
+
+// Read a vocab file with one vocabulary on each line, create a vector of
+// strings.
+std::vector<std::string> LoadVocabFromFile(const std::string& path_to_vocab);
+
+// Read a vocab buffer with one vocabulary on each line, create a vector of
+// strings.
+std::vector<std::string> LoadVocabFromBuffer(const char* vocab_buffer_data,
+ const size_t vocab_buffer_size);
+
+// Read a vocab file with one vocabulary and its corresponding index on each
+// line separated by space, create a map of <vocab, index>.
+absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile(
+ const std::string& path_to_vocab);
+
+// Read a vocab buffer with one vocabulary and its corresponding index on each
+// line separated by space, create a map of <vocab, index>.
+absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer(
+ const char* vocab_buffer_data, const size_t vocab_buffer_size);
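+
+// Illustrative use (the paths are hypothetical):
+//   std::vector<std::string> vocab = LoadVocabFromFile("/path/to/vocab.txt");
+//   absl::node_hash_map<std::string, int> index_map =
+//       LoadVocabAndIndexFromFile("/path/to/vocab_with_ids.txt");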
+} // namespace utils
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_UTILS_COMMON_UTILS_H_
diff --git a/tensorflow_lite_support/cc/utils/jni_utils.cc b/tensorflow_lite_support/cc/utils/jni_utils.cc
new file mode 100644
index 00000000..25cf3266
--- /dev/null
+++ b/tensorflow_lite_support/cc/utils/jni_utils.cc
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+
+#include <string.h>
+
+namespace tflite {
+namespace support {
+namespace utils {
+
+std::string JStringToString(JNIEnv* env, jstring jstr) {
+ if (jstr == nullptr) {
+ return std::string();
+ }
+ const char* cstring = env->GetStringUTFChars(jstr, nullptr);
+ std::string result(cstring);
+ env->ReleaseStringUTFChars(jstr, cstring);
+ return result;
+}
+
+std::vector<std::string> StringListToVector(JNIEnv* env, jobject list_object) {
+ jobject j_iterator = env->CallObjectMethod(
+ list_object, env->GetMethodID(env->GetObjectClass(list_object),
+ "iterator", "()Ljava/util/Iterator;"));
+ std::vector<std::string> result;
+ jmethodID has_next =
+ env->GetMethodID(env->GetObjectClass(j_iterator), "hasNext", "()Z");
+ jmethodID get_next = env->GetMethodID(env->GetObjectClass(j_iterator), "next",
+ "()Ljava/lang/Object;");
+ while (env->CallBooleanMethod(j_iterator, has_next)) {
+ jstring jstr =
+ static_cast<jstring>(env->CallObjectMethod(j_iterator, get_next));
+ const char* raw_str = env->GetStringUTFChars(jstr, nullptr);
+ result.emplace_back(std::string(raw_str));
+ env->ReleaseStringUTFChars(jstr, raw_str);
+ }
+ return result;
+}
+
+absl::string_view GetMappedFileBuffer(JNIEnv* env, const jobject& file_buffer) {
+ return absl::string_view(
+ static_cast<char*>(env->GetDirectBufferAddress(file_buffer)),
+ static_cast<size_t>(env->GetDirectBufferCapacity(file_buffer)));
+}
+
+jbyteArray CreateByteArray(JNIEnv* env, const jbyte* data, int num_bytes) {
+ jbyteArray ret = env->NewByteArray(num_bytes);
+ env->SetByteArrayRegion(ret, 0, num_bytes, data);
+
+ return ret;
+}
+
+void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
+ va_list args;
+ va_start(args, fmt);
+ const size_t max_msg_len = 512;
+ auto* message = static_cast<char*>(malloc(max_msg_len));
+ if (message && (vsnprintf(message, max_msg_len, fmt, args) >= 0)) {
+ ThrowExceptionWithMessage(env, clazz, message);
+ } else {
+ ThrowExceptionWithMessage(env, clazz, "");
+ }
+ if (message) {
+ free(message);
+ }
+ va_end(args);
+}
+
+void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz,
+ const char* message) {
+ jclass e_class = env->FindClass(clazz);
+ if (strcmp(clazz, kAssertionError) == 0) {
+ // AssertionError cannot use ThrowNew in Java 7
+ jmethodID constructor =
+ env->GetMethodID(e_class, "<init>", "(Ljava/lang/Object;)V");
+ jstring jstr_message = env->NewStringUTF(message);
+ jobject e_object = env->NewObject(e_class, constructor,
+ static_cast<jobject>(jstr_message));
+ env->Throw(static_cast<jthrowable>(e_object));
+ return;
+ }
+ env->ThrowNew(e_class, message);
+}
+
+} // namespace utils
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/cc/utils/jni_utils.h b/tensorflow_lite_support/cc/utils/jni_utils.h
new file mode 100644
index 00000000..4a4aae46
--- /dev/null
+++ b/tensorflow_lite_support/cc/utils/jni_utils.h
@@ -0,0 +1,91 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_UTILS_JNI_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_UTILS_JNI_UTILS_H_
+
+#include <jni.h>
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+
+namespace tflite {
+namespace support {
+namespace utils {
+
+const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
+const char kIllegalStateException[] = "java/lang/IllegalStateException";
+const char kNullPointerException[] = "java/lang/NullPointerException";
+const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException";
+const char kUnsupportedOperationException[] =
+ "java/lang/UnsupportedOperationException";
+const char kAssertionError[] = "java/lang/AssertionError";
+
+constexpr int kInvalidPointer = 0;
+
+// Check if t is nullptr, throw IllegalStateException if it is.
+// Used to verify different types of jobjects are correctly created from JNI.
+template <typename T>
+T CheckNotNull(JNIEnv* env, T&& t) {
+ if (t == nullptr) {
+ env->ThrowNew(env->FindClass(kIllegalStateException), "");
+ return nullptr;
+ }
+ return std::forward<T>(t);
+}
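+// e.g. jclass clazz = CheckNotNull(env, env->FindClass("java/lang/String"));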
+
+// Converts a std::vector<T> into a Java ArrayList using a converter, which
+// processes a single element in the vector before adding it to the ArrayList.
+template <typename T>
+jobject ConvertVectorToArrayList(JNIEnv* env, const std::vector<T>& results,
+ std::function<jobject(T)> converter) {
+ jclass array_list_class = env->FindClass("java/util/ArrayList");
+ jmethodID array_list_ctor =
+ env->GetMethodID(array_list_class, "<init>", "(I)V");
+ jint initial_capacity = static_cast<jint>(results.size());
+ jobject array_list_object =
+ env->NewObject(array_list_class, array_list_ctor, initial_capacity);
+ jmethodID array_list_add_method =
+ env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z");
+
+ for (const auto& ans : results) {
+ env->CallBooleanMethod(array_list_object, array_list_add_method,
+ converter(ans));
+ }
+ return array_list_object;
+}
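+// Illustrative use (names are hypothetical), converting a vector of strings
+// into an ArrayList of java.lang.String:
+//   jobject jlist = ConvertVectorToArrayList<std::string>(
+//       env, names,
+//       [env](const std::string& s) { return env->NewStringUTF(s.c_str()); });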
+
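+// Converts a Java string (jstring) into a std::string; returns an empty
+// string if jstr is null.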
+std::string JStringToString(JNIEnv* env, jstring jstr);
+
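+// Converts a java.util.List of Java strings into a std::vector<std::string>.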
+std::vector<std::string> StringListToVector(JNIEnv* env, jobject list_object);
+
+// Gets a mapped file buffer from a Java object representing a file.
+absl::string_view GetMappedFileBuffer(JNIEnv* env, const jobject& file_buffer);
+
+// Creates a Java byte array object based on the input data.
+jbyteArray CreateByteArray(JNIEnv* env, const jbyte* data, int num_bytes);
+
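+// Throws a Java exception of the given class with a printf-style formatted
+// message (truncated to 512 characters).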
+void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...);
+
+void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz,
+ const char* message);
+
+} // namespace utils
+} // namespace support
+} // namespace tflite
+#endif // TENSORFLOW_LITE_SUPPORT_CC_UTILS_JNI_UTILS_H_
diff --git a/tensorflow_lite_support/codegen/BUILD b/tensorflow_lite_support/codegen/BUILD
new file mode 100644
index 00000000..b224f987
--- /dev/null
+++ b/tensorflow_lite_support/codegen/BUILD
@@ -0,0 +1,86 @@
+# The tools for generating wrapper classes for a TFLite model with metadata.
+
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "utils",
+ srcs = [
+ "utils.cc",
+ ],
+ hdrs = [
+ "utils.h",
+ ],
+ deps = [
+ ],
+)
+
+cc_library(
+ name = "code_generator",
+ srcs = [
+ "code_generator.cc",
+ ],
+ hdrs = [
+ "code_generator.h",
+ ],
+ deps = [
+ ":utils",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ ],
+)
+
+cc_library(
+ name = "metadata_helper",
+ srcs = [
+ "metadata_helper.cc",
+ ],
+ hdrs = [
+ "metadata_helper.h",
+ ],
+ deps = [
+ ":utils",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
+
+cc_library(
+ name = "android_java_generator",
+ srcs = [
+ "android_java_generator.cc",
+ ],
+ hdrs = [
+ "android_java_generator.h",
+ ],
+ deps = [
+ ":code_generator",
+ ":metadata_helper",
+ ":utils",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "code_generator_test",
+ size = "small",
+ srcs = ["code_generator_test.cc"],
+ data = ["//tensorflow_lite_support/metadata:metadata_schema.fbs"],
+ deps = [
+ ":code_generator",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_test(
+ name = "utils_test",
+ srcs = ["utils_test.cc"],
+ deps = [
+ ":utils",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow_lite_support/codegen/README.md b/tensorflow_lite_support/codegen/README.md
new file mode 100644
index 00000000..d457edd1
--- /dev/null
+++ b/tensorflow_lite_support/codegen/README.md
@@ -0,0 +1,13 @@
+# TensorFlow Lite Android Wrapper Code Generator
+
+For TensorFlow Lite models enhanced with [metadata](https://www.tensorflow.org/lite/convert/metadata.md),
+developers can use the TensorFlow Lite Android wrapper code generator to create
+platform specific wrapper code. The wrapper code removes the need to interact
+directly with `ByteBuffer`. Instead, developers can interact with the TensorFlow
+Lite model with typed objects such as `Bitmap` and `Rect`.
+
+The usefulness of the code generator depends on the completeness of the
+TensorFlow Lite model's metadata entry. Refer to the `<Codegen usage>` section
+under relevant fields in
+[metadata_schema.fbs](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/metadata_schema.fbs),
+to see how the codegen tool parses each field.
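+
+As a sketch of the developer experience, a wrapper generated for a
+hypothetical image classification model could be used as follows. The class
+name `MyModel` and the output name `probability` are illustrative; the actual
+names come from the model's metadata:
+
+```
+// May throw IOException if the model asset cannot be loaded.
+MyModel model = MyModel.newInstance(context);
+TensorImage image = TensorImage.fromBitmap(bitmap);
+MyModel.Outputs outputs = model.process(image);
+List<Category> probability = outputs.getProbabilityAsCategoryList();
+model.close();
+```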
diff --git a/tensorflow_lite_support/codegen/android_java_generator.cc b/tensorflow_lite_support/codegen/android_java_generator.cc
new file mode 100644
index 00000000..097119f1
--- /dev/null
+++ b/tensorflow_lite_support/codegen/android_java_generator.cc
@@ -0,0 +1,1017 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file contains the logic of Android model wrapper generation.
+//
+// It starts with the helper functions that handle the metadata and the code
+// writer.
+//
+// Code is generated in the `Generate{FOO}` functions. The Gradle and Manifest
+// files are simple. The wrapper file generation is more complex, so we divided
+// it into several sub-functions.
+//
+// The structure of the wrapper file looks like:
+//
+// [ imports ]
+// [ class ]
+// [ inner "Outputs" class ]
+// [ inner "Metadata" class ]
+// [ APIs ] ( including ctors, public APIs and private APIs )
+//
+// We write it mostly in a "template-generation" way. `CodeWriter` does the
+// job of a template renderer. To avoid repeatedly setting the token values,
+// the helper functions `SetCodeWriterWith{Foo}Info` set them from info
+// structures (`TensorInfo` and `ModelInfo`) - the Info structures are
+// intermediate data structures between the Metadata (represented in
+// Flatbuffers) and the generated code.
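+//
+// As an illustration only (not actual generator code), the `CodeWriter`
+// template mechanism works roughly like:
+//
+//   CodeWriter writer(&err);
+//   writer.SetTokenValue("NAME", "image");
+//   writer.Append("private TensorImage {{NAME}};");
+//   // Renders the line: private TensorImage image;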
+
+#include "tensorflow_lite_support/codegen/android_java_generator.h"
+
+#include <ctype.h>
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow_lite_support/codegen/code_generator.h"
+#include "tensorflow_lite_support/codegen/metadata_helper.h"
+#include "tensorflow_lite_support/codegen/utils.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+namespace {
+
+using details_android_java::ModelInfo;
+using details_android_java::TensorInfo;
+
+// Helper class that wraps a generated code block in braces.
+// The ctor and dtor simulate an enter/exit scope, like `with` in Python.
+class AsBlock {
+ public:
+ AsBlock(CodeWriter* code_writer, const std::string& before,
+ bool trailing_blank_line = false)
+ : code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) {
+ code_writer_->AppendNoNewLine(before);
+ code_writer_->Append(" {");
+ code_writer_->Indent();
+ }
+ ~AsBlock() {
+ code_writer_->Outdent();
+ code_writer_->Append("}");
+ if (trailing_blank_line_) {
+ code_writer_->NewLine();
+ }
+ }
+
+ private:
+ CodeWriter* code_writer_;
+ bool trailing_blank_line_;
+};
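+// For example, `AsBlock block(code_writer, "public class Foo");` emits
+// "public class Foo {" and indents; when `block` goes out of scope, the
+// destructor outdents and emits the closing "}".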
+
+// Declare the functions first, so that their definitions below can follow a
+// logical order.
+bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*);
+bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*);
+bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
+bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*);
+bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*);
+bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*);
+
+std::string GetModelVersionedName(const ModelMetadata* metadata) {
+ std::string model_name = "MyModel";
+ if (metadata->name() != nullptr && !(metadata->name()->str().empty())) {
+ model_name = metadata->name()->str();
+ }
+ std::string model_version = "unknown";
+ if (metadata->version() != nullptr && !(metadata->version()->str().empty())) {
+ model_version = metadata->version()->str();
+ }
+ return model_name + " (Version: " + model_version + ")";
+}
+
+TensorInfo CreateTensorInfo(const TensorMetadata* metadata,
+ const std::string& name, bool is_input, int index,
+ ErrorReporter* err) {
+ TensorInfo tensor_info;
+ std::string tensor_identifier = is_input ? "input" : "output";
+ tensor_identifier += " " + std::to_string(index);
+ tensor_info.associated_axis_label_index = FindAssociatedFile(
+ metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err);
+ tensor_info.associated_value_label_index = FindAssociatedFile(
+ metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err);
+ if (is_input && (tensor_info.associated_axis_label_index >= 0 ||
+ tensor_info.associated_value_label_index >= 0)) {
+ err->Warning(
+ "Found label file on input tensor (%s). Label file for input "
+ "tensor is not supported yet. The "
+ "file will be ignored.",
+ tensor_identifier.c_str());
+ }
+ if (tensor_info.associated_axis_label_index >= 0 &&
+ tensor_info.associated_value_label_index >= 0) {
+ err->Warning(
+ "Found both axis label file and value label file for tensor (%s), "
+ "which is not supported. Only the axis label file will be used.",
+ tensor_identifier.c_str());
+ }
+ tensor_info.is_input = is_input;
+ tensor_info.name = SnakeCaseToCamelCase(name);
+ tensor_info.upper_camel_name = tensor_info.name;
+ tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]);
+ tensor_info.normalization_unit =
+ FindNormalizationUnit(metadata, tensor_identifier, err);
+ if (metadata->content() != nullptr &&
+ metadata->content()->content_properties() != nullptr) {
+ // Infer the tensor wrapper type from the content metadata.
+ if (metadata->content()->content_properties_type() ==
+ ContentProperties_ImageProperties) {
+ if (metadata->content()
+ ->content_properties_as_ImageProperties()
+ ->color_space() == ColorSpaceType_RGB) {
+ tensor_info.content_type = "image";
+ tensor_info.wrapper_type = "TensorImage";
+ tensor_info.processor_type = "ImageProcessor";
+ return tensor_info;
+ } else {
+ err->Warning(
+ "Found Non-RGB image on tensor (%s). Codegen currently does not "
+ "support it, and regard it as a plain numeric tensor.",
+ tensor_identifier.c_str());
+ }
+ }
+ }
+ tensor_info.content_type = "tensor";
+ tensor_info.wrapper_type = "TensorBuffer";
+ tensor_info.processor_type = "TensorProcessor";
+ return tensor_info;
+}
+
+ModelInfo CreateModelInfo(const ModelMetadata* metadata,
+ const std::string& package_name,
+ const std::string& model_class_name,
+ const std::string& model_asset_path,
+ ErrorReporter* err) {
+ ModelInfo model_info;
+ if (!CodeGenerator::VerifyMetadata(metadata, err)) {
+ // TODO(b/150116380): Create dummy model info.
+ err->Error("Validating metadata failed.");
+ return model_info;
+ }
+ model_info.package_name = package_name;
+ model_info.model_class_name = model_class_name;
+ model_info.model_asset_path = model_asset_path;
+ model_info.model_versioned_name = GetModelVersionedName(metadata);
+ const auto* graph = metadata->subgraph_metadata()->Get(0);
+ auto names = CodeGenerator::NameInputsAndOutputs(
+ graph->input_tensor_metadata(), graph->output_tensor_metadata());
+ std::vector<std::string> input_tensor_names = std::move(names.first);
+ std::vector<std::string> output_tensor_names = std::move(names.second);
+
+ for (int i = 0; i < input_tensor_names.size(); i++) {
+ model_info.inputs.push_back(
+ CreateTensorInfo(graph->input_tensor_metadata()->Get(i),
+ input_tensor_names[i], true, i, err));
+ if (i < input_tensor_names.size() - 1) {
+ model_info.inputs_list += ", ";
+ model_info.input_type_param_list += ", ";
+ }
+ model_info.inputs_list += model_info.inputs[i].name;
+ model_info.input_type_param_list +=
+ model_info.inputs[i].wrapper_type + " " + model_info.inputs[i].name;
+ }
+ for (int i = 0; i < output_tensor_names.size(); i++) {
+ model_info.outputs.push_back(
+ CreateTensorInfo(graph->output_tensor_metadata()->Get(i),
+ output_tensor_names[i], false, i, err));
+ if (i < output_tensor_names.size() - 1) {
+ model_info.postprocessor_type_param_list += ", ";
+ model_info.postprocessors_list += ", ";
+ }
+ model_info.postprocessors_list +=
+ model_info.outputs[i].name + "Postprocessor";
+ model_info.postprocessor_type_param_list +=
+ model_info.outputs[i].processor_type + " " +
+ model_info.outputs[i].name + "Postprocessor";
+ }
+ return model_info;
+}
+
+void SetCodeWriterWithTensorInfo(CodeWriter* code_writer,
+ const TensorInfo& tensor_info) {
+ code_writer->SetTokenValue("NAME", tensor_info.name);
+ code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name);
+ code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type);
+ code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type);
+ std::string wrapper_name = tensor_info.wrapper_type;
+ wrapper_name[0] = tolower(wrapper_name[0]);
+ code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name);
+ code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type);
+ code_writer->SetTokenValue("NORMALIZATION_UNIT",
+ std::to_string(tensor_info.normalization_unit));
+ code_writer->SetTokenValue(
+ "ASSOCIATED_AXIS_LABEL_INDEX",
+ std::to_string(tensor_info.associated_axis_label_index));
+ code_writer->SetTokenValue(
+ "ASSOCIATED_VALUE_LABEL_INDEX",
+ std::to_string(tensor_info.associated_value_label_index));
+}
+
+void SetCodeWriterWithModelInfo(CodeWriter* code_writer,
+ const ModelInfo& model_info) {
+ code_writer->SetTokenValue("PACKAGE", model_info.package_name);
+ code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path);
+ code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name);
+ // Extra info, half generated.
+ code_writer->SetTokenValue("INPUT_TYPE_PARAM_LIST",
+ model_info.input_type_param_list);
+ code_writer->SetTokenValue("INPUTS_LIST", model_info.inputs_list);
+ code_writer->SetTokenValue("POSTPROCESSORS_LIST",
+ model_info.postprocessors_list);
+ code_writer->SetTokenValue("POSTPROCESSOR_TYPE_PARAM_LIST",
+ model_info.postprocessor_type_param_list);
+}
+
+constexpr char JAVA_DEFAULT_PACKAGE[] = "default";
+
+std::string ConvertPackageToPath(const std::string& package) {
+ if (package == JAVA_DEFAULT_PACKAGE) {
+ return "";
+ }
+ std::string path = package;
+ std::replace(path.begin(), path.end(), '.', '/');
+ return path;
+}
+
+bool IsImageUsed(const ModelInfo& model) {
+ for (const auto& input : model.inputs) {
+ if (input.content_type == "image") {
+ return true;
+ }
+ }
+ for (const auto& output : model.outputs) {
+ if (output.content_type == "image") {
+ return true;
+ }
+ }
+ return false;
+}
+
+// The following functions generate the wrapper Java code for a model.
+
+bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model,
+ ErrorReporter* err) {
+ code_writer->Append("// Generated by TFLite Support.");
+ code_writer->Append("package {{PACKAGE}};");
+ code_writer->NewLine();
+
+ if (!GenerateWrapperImports(code_writer, model, err)) {
+ err->Error("Fail to generate imports for wrapper class.");
+ return false;
+ }
+ if (!GenerateWrapperClass(code_writer, model, err)) {
+ err->Error("Fail to generate wrapper class.");
+ return false;
+ }
+ code_writer->NewLine();
+ return true;
+}
+
+bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model,
+ ErrorReporter* err) {
+ const std::string support_pkg = "org.tensorflow.lite.support.";
+ std::vector<std::string> imports{
+ "android.content.Context",
+ "java.io.IOException",
+ "java.nio.ByteBuffer",
+ "java.nio.FloatBuffer",
+ "java.util.Arrays",
+ "java.util.HashMap",
+ "java.util.List",
+ "java.util.Map",
+ "org.tensorflow.lite.DataType",
+ "org.tensorflow.lite.Tensor",
+ "org.tensorflow.lite.Tensor.QuantizationParams",
+ support_pkg + "common.FileUtil",
+ support_pkg + "common.TensorProcessor",
+ support_pkg + "common.ops.CastOp",
+ support_pkg + "common.ops.DequantizeOp",
+ support_pkg + "common.ops.NormalizeOp",
+ support_pkg + "common.ops.QuantizeOp",
+ support_pkg + "label.Category",
+ support_pkg + "label.TensorLabel",
+ support_pkg + "metadata.MetadataExtractor",
+ support_pkg + "metadata.schema.NormalizationOptions",
+ support_pkg + "model.Model",
+ support_pkg + "tensorbuffer.TensorBuffer",
+ };
+ if (IsImageUsed(model)) {
+ for (const auto& target :
+ {"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp",
+ "image.ops.ResizeOp.ResizeMethod"}) {
+ imports.push_back(support_pkg + target);
+ }
+ }
+
+ std::sort(imports.begin(), imports.end());
+ for (const auto& target : imports) {
+ code_writer->SetTokenValue("TARGET", target);
+ code_writer->Append("import {{TARGET}};");
+ }
+ code_writer->NewLine();
+ return true;
+}
+
+bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model,
+ ErrorReporter* err) {
+ code_writer->SetTokenValue("MODEL_VERSIONED_NAME",
+ model.model_versioned_name);
+ code_writer->Append(
+ R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)");
+ const auto code_block =
+ AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}");
+ code_writer->Append(R"(private final Metadata metadata;
+private final Model model;
+private static final String MODEL_NAME = "{{MODEL_PATH}}";)");
+ for (const auto& tensor : model.inputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;");
+ }
+ for (const auto& tensor : model.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
+ }
+ code_writer->NewLine();
+ if (!GenerateWrapperOutputs(code_writer, model, err)) {
+ err->Error("Failed to generate output classes");
+ return false;
+ }
+ code_writer->NewLine();
+ if (!GenerateWrapperMetadata(code_writer, model, err)) {
+ err->Error("Failed to generate the metadata class");
+ return false;
+ }
+ code_writer->NewLine();
+ if (!GenerateWrapperAPI(code_writer, model, err)) {
+ err->Error("Failed to generate the common APIs");
+ return false;
+ }
+ return true;
+}
+
+bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model,
+ ErrorReporter* err) {
+ code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */");
+ auto class_block = AsBlock(code_writer, "public static class Outputs");
+ for (const auto& tensor : model.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};");
+ if (tensor.associated_axis_label_index >= 0) {
+ code_writer->Append("private final List<String> {{NAME}}Labels;");
+ }
+ code_writer->Append(
+ "private final {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;");
+ }
+ // Getters
+ for (const auto& tensor : model.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->NewLine();
+ if (tensor.associated_axis_label_index >= 0) {
+ if (tensor.content_type == "tensor") {
+ code_writer->Append(
+ R"(public List<Category> get{{NAME_U}}AsCategoryList() {
+ return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getCategoryList();
+})");
+ } else { // image
+ err->Warning(
+ "Axis label for images is not supported. The labels will "
+ "be ignored.");
+ }
+ } else { // no label
+ code_writer->Append(
+ R"(public {{WRAPPER_TYPE}} get{{NAME_U}}As{{WRAPPER_TYPE}}() {
+ return postprocess{{NAME_U}}({{NAME}});
+})");
+ }
+ }
+ code_writer->NewLine();
+ {
+ const auto ctor_block = AsBlock(
+ code_writer,
+ "Outputs(Metadata metadata, {{POSTPROCESSOR_TYPE_PARAM_LIST}})");
+ for (const auto& tensor : model.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ if (tensor.content_type == "image") {
+ code_writer->Append(
+ R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type());
+{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)");
+ } else { // FEATURE, UNKNOWN
+ code_writer->Append(
+ "{{NAME}} = "
+ "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), "
+ "metadata.get{{NAME_U}}Type());");
+ }
+ if (tensor.associated_axis_label_index >= 0) {
+ code_writer->Append("{{NAME}}Labels = metadata.get{{NAME_U}}Labels();");
+ }
+ code_writer->Append(
+ "this.{{NAME}}Postprocessor = {{NAME}}Postprocessor;");
+ }
+ }
+ code_writer->NewLine();
+ {
+ const auto get_buffer_block =
+ AsBlock(code_writer, "Map<Integer, Object> getBuffer()");
+ code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();");
+ for (int i = 0; i < model.outputs.size(); i++) {
+ SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
+ code_writer->SetTokenValue("ID", std::to_string(i));
+ code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());");
+ }
+ code_writer->Append("return outputs;");
+ }
+ for (const auto& tensor : model.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->NewLine();
+ {
+ auto processor_block =
+ AsBlock(code_writer,
+ "private {{WRAPPER_TYPE}} "
+ "postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})");
+ code_writer->Append(
+ "return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});");
+ }
+ }
+ return true;
+}
+
+bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model,
+ ErrorReporter* err) {
+ code_writer->Append(
+ "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */");
+ const auto class_block = AsBlock(code_writer, "public static class Metadata");
+ for (const auto& tensor : model.inputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append(R"(private final int[] {{NAME}}Shape;
+private final DataType {{NAME}}DataType;
+private final QuantizationParams {{NAME}}QuantizationParams;)");
+ if (tensor.normalization_unit >= 0) {
+ code_writer->Append(R"(private final float[] {{NAME}}Mean;
+private final float[] {{NAME}}Stddev;)");
+ }
+ }
+ for (const auto& tensor : model.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append(R"(private final int[] {{NAME}}Shape;
+private final DataType {{NAME}}DataType;
+private final QuantizationParams {{NAME}}QuantizationParams;)");
+ if (tensor.normalization_unit >= 0) {
+ code_writer->Append(R"(private final float[] {{NAME}}Mean;
+private final float[] {{NAME}}Stddev;)");
+ }
+ if (tensor.associated_axis_label_index >= 0 ||
+ tensor.associated_value_label_index >= 0) {
+ code_writer->Append("private final List<String> {{NAME}}Labels;");
+ }
+ }
+ code_writer->NewLine();
+ {
+ const auto ctor_block = AsBlock(
+ code_writer,
+ "public Metadata(ByteBuffer buffer, Model model) throws IOException");
+ code_writer->Append(
+ "MetadataExtractor extractor = new MetadataExtractor(buffer);");
+ for (int i = 0; i < model.inputs.size(); i++) {
+ SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]);
+ code_writer->SetTokenValue("ID", std::to_string(i));
+ code_writer->Append(
+ R"(Tensor {{NAME}}Tensor = model.getInputTensor({{ID}});
+{{NAME}}Shape = {{NAME}}Tensor.shape();
+{{NAME}}DataType = {{NAME}}Tensor.dataType();
+{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
+ if (model.inputs[i].normalization_unit >= 0) {
+ code_writer->Append(
+ R"(NormalizationOptions {{NAME}}NormalizationOptions =
+ (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
+FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
+{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
+{{NAME}}MeanBuffer.get({{NAME}}Mean);
+FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
+{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
+{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
+ }
+ }
+ for (int i = 0; i < model.outputs.size(); i++) {
+ SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]);
+ code_writer->SetTokenValue("ID", std::to_string(i));
+ code_writer->Append(
+ R"(Tensor {{NAME}}Tensor = model.getOutputTensor({{ID}});
+{{NAME}}Shape = {{NAME}}Tensor.shape();
+{{NAME}}DataType = {{NAME}}Tensor.dataType();
+{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)");
+ if (model.outputs[i].normalization_unit >= 0) {
+ code_writer->Append(
+ R"(NormalizationOptions {{NAME}}NormalizationOptions =
+ (NormalizationOptions) extractor.getOutputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions());
+FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer();
+{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()];
+{{NAME}}MeanBuffer.get({{NAME}}Mean);
+FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer();
+{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()];
+{{NAME}}StddevBuffer.get({{NAME}}Stddev);)");
+ }
+ if (model.outputs[i].associated_axis_label_index >= 0) {
+ code_writer->Append(R"(String {{NAME}}LabelsFileName =
+ extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name();
+{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
+ } else if (model.outputs[i].associated_value_label_index >= 0) {
+ code_writer->Append(R"(String {{NAME}}LabelsFileName =
+ extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name();
+{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)");
+ }
+ }
+ }
+ for (const auto& tensor : model.inputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append(R"(
+public int[] get{{NAME_U}}Shape() {
+ return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
+}
+
+public DataType get{{NAME_U}}Type() {
+ return {{NAME}}DataType;
+}
+
+public QuantizationParams get{{NAME_U}}QuantizationParams() {
+ return {{NAME}}QuantizationParams;
+})");
+ if (tensor.normalization_unit >= 0) {
+ code_writer->Append(R"(
+public float[] get{{NAME_U}}Mean() {
+ return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
+}
+
+public float[] get{{NAME_U}}Stddev() {
+ return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
+})");
+ }
+ }
+ for (const auto& tensor : model.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append(R"(
+public int[] get{{NAME_U}}Shape() {
+ return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length);
+}
+
+public DataType get{{NAME_U}}Type() {
+ return {{NAME}}DataType;
+}
+
+public QuantizationParams get{{NAME_U}}QuantizationParams() {
+ return {{NAME}}QuantizationParams;
+})");
+ if (tensor.normalization_unit >= 0) {
+ code_writer->Append(R"(
+public float[] get{{NAME_U}}Mean() {
+ return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length);
+}
+
+public float[] get{{NAME_U}}Stddev() {
+ return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length);
+})");
+ }
+ if (tensor.associated_axis_label_index >= 0 ||
+ tensor.associated_value_label_index >= 0) {
+ code_writer->Append(R"(
+public List<String> get{{NAME_U}}Labels() {
+ return {{NAME}}Labels;
+})");
+ }
+ }
+ return true;
+}
+
+bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model,
+ ErrorReporter* err) {
+ code_writer->Append(R"(public Metadata getMetadata() {
+ return metadata;
+}
+)");
+ code_writer->Append(R"(/**
+ * Creates interpreter and loads associated files if needed.
+ *
+ * @throws IOException if an I/O error occurs when loading the tflite model.
+ */
+public static {{MODEL_CLASS_NAME}} newInstance(Context context) throws IOException {
+ return newInstance(context, MODEL_NAME, new Model.Options.Builder().build());
+}
+
+/**
+ * Creates interpreter and loads associated files if needed, loading another model with the same
+ * input / output structure as the original one.
+ *
+ * @throws IOException if an I/O error occurs when loading the tflite model.
+ */
+public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath) throws IOException {
+ return newInstance(context, modelPath, new Model.Options.Builder().build());
+}
+
+/**
+ * Creates interpreter and loads associated files if needed, with running options configured.
+ *
+ * @throws IOException if an I/O error occurs when loading the tflite model.
+ */
+public static {{MODEL_CLASS_NAME}} newInstance(Context context, Model.Options runningOptions) throws IOException {
+ return newInstance(context, MODEL_NAME, runningOptions);
+}
+
+/**
+ * Creates interpreter for a user-specified model.
+ *
+ * @throws IOException if an I/O error occurs when loading the tflite model.
+ */
+public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath, Model.Options runningOptions) throws IOException {
+ Model model = Model.createModel(context, modelPath, runningOptions);
+ Metadata metadata = new Metadata(model.getData(), model);
+ {{MODEL_CLASS_NAME}} instance = new {{MODEL_CLASS_NAME}}(model, metadata);)");
+ for (const auto& tensor : model.inputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append(
+ R"( instance.reset{{NAME_U}}Preprocessor(
+ instance.buildDefault{{NAME_U}}Preprocessor());)");
+ }
+ for (const auto& tensor : model.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append(
+ R"( instance.reset{{NAME_U}}Postprocessor(
+ instance.buildDefault{{NAME_U}}Postprocessor());)");
+ }
+ code_writer->Append(R"( return instance;
+}
+)");
+
+ // Pre, post processor setters
+ for (const auto& tensor : model.inputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append(R"(
+public void reset{{NAME_U}}Preprocessor({{PROCESSOR_TYPE}} processor) {
+ {{NAME}}Preprocessor = processor;
+})");
+ }
+ for (const auto& tensor : model.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append(R"(
+public void reset{{NAME_U}}Postprocessor({{PROCESSOR_TYPE}} processor) {
+ {{NAME}}Postprocessor = processor;
+})");
+ }
+ // Process method
+ code_writer->Append(R"(
+/** Triggers the model. */
+public Outputs process({{INPUT_TYPE_PARAM_LIST}}) {
+ Outputs outputs = new Outputs(metadata, {{POSTPROCESSORS_LIST}});
+ Object[] inputBuffers = preprocessInputs({{INPUTS_LIST}});
+ model.run(inputBuffers, outputs.getBuffer());
+ return outputs;
+}
+
+/** Closes the model. */
+public void close() {
+ model.close();
+}
+)");
+ {
+ auto block =
+ AsBlock(code_writer,
+ "private {{MODEL_CLASS_NAME}}(Model model, Metadata metadata)");
+ code_writer->Append(R"(this.model = model;
+this.metadata = metadata;)");
+ }
+ for (const auto& tensor : model.inputs) {
+ code_writer->NewLine();
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ auto block = AsBlock(
+ code_writer,
+ "private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Preprocessor()");
+ code_writer->Append(
+ "{{PROCESSOR_TYPE}}.Builder builder = new "
+ "{{PROCESSOR_TYPE}}.Builder()");
+ if (tensor.content_type == "image") {
+ code_writer->Append(R"( .add(new ResizeOp(
+ metadata.get{{NAME_U}}Shape()[1],
+ metadata.get{{NAME_U}}Shape()[2],
+ ResizeMethod.NEAREST_NEIGHBOR)))");
+ }
+ if (tensor.normalization_unit >= 0) {
+ code_writer->Append(
+ R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
+ }
+ code_writer->Append(
+ R"( .add(new QuantizeOp(
+ metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
+ metadata.get{{NAME_U}}QuantizationParams().getScale()))
+ .add(new CastOp(metadata.get{{NAME_U}}Type()));
+return builder.build();)");
+ }
+ for (const auto& tensor : model.outputs) {
+ code_writer->NewLine();
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ auto block = AsBlock(
+ code_writer,
+ "private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Postprocessor()");
+ code_writer->AppendNoNewLine(
+ R"({{PROCESSOR_TYPE}}.Builder builder = new {{PROCESSOR_TYPE}}.Builder()
+ .add(new DequantizeOp(
+ metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(),
+ metadata.get{{NAME_U}}QuantizationParams().getScale())))");
+ if (tensor.normalization_unit >= 0) {
+ code_writer->AppendNoNewLine(R"(
+ .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))");
+ }
+ code_writer->Append(R"(;
+return builder.build();)");
+ }
+ code_writer->NewLine();
+ {
+ const auto block =
+ AsBlock(code_writer,
+ "private Object[] preprocessInputs({{INPUT_TYPE_PARAM_LIST}})");
+ CodeWriter param_list_gen(err);
+ for (const auto& tensor : model.inputs) {
+ SetCodeWriterWithTensorInfo(code_writer, tensor);
+ code_writer->Append("{{NAME}} = {{NAME}}Preprocessor.process({{NAME}});");
+ SetCodeWriterWithTensorInfo(&param_list_gen, tensor);
+ param_list_gen.AppendNoNewLine("{{NAME}}.getBuffer(), ");
+ }
+ param_list_gen.Backspace(2);
+ code_writer->AppendNoNewLine("return new Object[] {");
+ code_writer->AppendNoNewLine(param_list_gen.ToString());
+ code_writer->Append("};");
+ }
+ return true;
+}
+
+bool GenerateBuildGradleContent(CodeWriter* code_writer,
+ const ModelInfo& model_info) {
+ code_writer->Append(R"(buildscript {
+ repositories {
+ google()
+ jcenter()
+ }
+ dependencies {
+ classpath 'com.android.tools.build:gradle:3.2.1'
+ }
+}
+
+allprojects {
+ repositories {
+ google()
+ jcenter()
+ flatDir {
+ dirs 'libs'
+ }
+ }
+}
+
+apply plugin: 'com.android.library'
+
+android {
+ compileSdkVersion 29
+ defaultConfig {
+ targetSdkVersion 29
+ versionCode 1
+ versionName "1.0"
+ }
+ aaptOptions {
+ noCompress "tflite"
+ }
+ compileOptions {
+ sourceCompatibility = '1.8'
+ targetCompatibility = '1.8'
+ }
+ lintOptions {
+ abortOnError false
+ }
+}
+
+configurations {
+ libMetadata
+}
+
+dependencies {
+ libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic'
+}
+
+task downloadLibs(type: Sync) {
+ from configurations.libMetadata
+ into "$buildDir/libs"
+ rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar"
+}
+
+preBuild.dependsOn downloadLibs
+
+dependencies {
+ compileOnly 'org.checkerframework:checker-qual:2.5.8'
+ api 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
+ api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
+ api files("$buildDir/libs/tensorflow-lite-support-metadata.jar")
+ implementation 'org.apache.commons:commons-compress:1.19'
+})");
+ return true;
+}
+
+bool GenerateAndroidManifestContent(CodeWriter* code_writer,
+ const ModelInfo& model_info) {
+ code_writer->Append(R"(<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="{{PACKAGE}}">
+</manifest>)");
+ return true;
+}
+
+bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) {
+ code_writer->Append("# {{MODEL_CLASS_NAME}} Usage");
+ // TODO(b/158651848) Generate imports for TFLS util types like TensorImage.
+ code_writer->AppendNoNewLine(R"(
+```
+import {{PACKAGE}}.{{MODEL_CLASS_NAME}};
+
+// 1. Initialize the Model
+{{MODEL_CLASS_NAME}} model = null;
+
+try {
+ model = {{MODEL_CLASS_NAME}}.newInstance(context); // android.content.Context
+} catch (IOException e) {
+ e.printStackTrace();
+}
+
+if (model != null) {
+
+ // 2. Set the inputs)");
+ for (const auto& t : model_info.inputs) {
+ SetCodeWriterWithTensorInfo(code_writer, t);
+ if (t.content_type == "image") {
+ code_writer->Append(R"(
+ // Prepare tensor "{{NAME}}" from a Bitmap with ARGB_8888 format.
+ Bitmap bitmap = ...;
+ TensorImage {{NAME}} = TensorImage.fromBitmap(bitmap);
+ // Alternatively, load the input tensor "{{NAME}}" from pixel values.
+ // Check out TensorImage documentation to load other image data structures.
+ // int[] pixelValues = ...;
+ // int[] shape = ...;
+ // TensorImage {{NAME}} = new TensorImage();
+ // {{NAME}}.load(pixelValues, shape);)");
+ } else {
+ code_writer->Append(R"(
+ // Prepare input tensor "{{NAME}}" from an array.
+ // Check out TensorBuffer documentation to load other data structures.
+ TensorBuffer {{NAME}} = ...;
+ int[] values = ...;
+ int[] shape = ...;
+ {{NAME}}.load(values, shape);)");
+ }
+ }
+ code_writer->Append(R"(
+ // 3. Run the model
+ {{MODEL_CLASS_NAME}}.Outputs outputs = model.process({{INPUTS_LIST}});)");
+ code_writer->Append(R"(
+ // 4. Retrieve the results)");
+ for (const auto& t : model_info.outputs) {
+ SetCodeWriterWithTensorInfo(code_writer, t);
+ if (t.associated_axis_label_index >= 0) {
+ code_writer->SetTokenValue("WRAPPER_TYPE", "List<Category>");
+ code_writer->Append(
+ " List<Category> {{NAME}} = "
+ "outputs.get{{NAME_U}}AsCategoryList();");
+ } else {
+ code_writer->Append(
+ " {{WRAPPER_TYPE}} {{NAME}} = "
+ "outputs.get{{NAME_U}}As{{WRAPPER_TYPE}}();");
+ }
+ }
+ code_writer->Append(R"(}
+```)");
+ return true;
+}
+
+GenerationResult::File GenerateWrapperFile(const std::string& module_root,
+ const ModelInfo& model_info,
+ ErrorReporter* err) {
+ const auto java_path = JoinPath(module_root, "src/main/java");
+ const auto package_path =
+ JoinPath(java_path, ConvertPackageToPath(model_info.package_name));
+ const auto file_path =
+ JoinPath(package_path, model_info.model_class_name + JAVA_EXT);
+
+ CodeWriter code_writer(err);
+ code_writer.SetIndentString(" ");
+ SetCodeWriterWithModelInfo(&code_writer, model_info);
+
+ if (!GenerateWrapperFileContent(&code_writer, model_info, err)) {
+ err->Error("Generating Java wrapper content failed.");
+ }
+
+ const auto java_file = code_writer.ToString();
+ return GenerationResult::File{file_path, java_file};
+}
+
+GenerationResult::File GenerateBuildGradle(const std::string& module_root,
+ const ModelInfo& model_info,
+ ErrorReporter* err) {
+ const auto file_path = JoinPath(module_root, "build.gradle");
+ CodeWriter code_writer(err);
+ SetCodeWriterWithModelInfo(&code_writer, model_info);
+ if (!GenerateBuildGradleContent(&code_writer, model_info)) {
+ err->Error("Generating build.gradle failed.");
+ }
+ const auto content = code_writer.ToString();
+ return GenerationResult::File{file_path, content};
+}
+
+GenerationResult::File GenerateAndroidManifest(const std::string& module_root,
+ const ModelInfo& model_info,
+ ErrorReporter* err) {
+ const auto file_path = JoinPath(module_root, "src/main/AndroidManifest.xml");
+ CodeWriter code_writer(err);
+ SetCodeWriterWithModelInfo(&code_writer, model_info);
+ if (!GenerateAndroidManifestContent(&code_writer, model_info)) {
+ err->Error("Generating AndroidManifest.xml failed.");
+ }
+ return GenerationResult::File{file_path, code_writer.ToString()};
+}
+
+GenerationResult::File GenerateDoc(const std::string& module_root,
+ const ModelInfo& model_info,
+ ErrorReporter* err) {
+ std::string lower = model_info.model_class_name;
+ for (int i = 0; i < lower.length(); i++) {
+    lower[i] = std::tolower(static_cast<unsigned char>(lower[i]));
+ }
+ const auto file_path = JoinPath(module_root, lower + ".md");
+ CodeWriter code_writer(err);
+ SetCodeWriterWithModelInfo(&code_writer, model_info);
+ if (!GenerateDocContent(&code_writer, model_info)) {
+ err->Error("Generating doc failed.");
+ }
+ return GenerationResult::File{file_path, code_writer.ToString()};
+}
+
+} // namespace
+
+AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root)
+ : CodeGenerator(), module_root_(module_root) {}
+
+GenerationResult AndroidJavaGenerator::Generate(
+ const Model* model, const std::string& package_name,
+ const std::string& model_class_name, const std::string& model_asset_path) {
+ GenerationResult result;
+ if (model == nullptr) {
+ err_.Error(
+ "Cannot read model from the buffer. Codegen will generate nothing.");
+ return result;
+ }
+ const ModelMetadata* metadata = GetMetadataFromModel(model);
+ if (metadata == nullptr) {
+ err_.Error(
+ "Cannot find TFLite Metadata in the model. Codegen will generate "
+ "nothing.");
+ return result;
+ }
+ details_android_java::ModelInfo model_info = CreateModelInfo(
+ metadata, package_name, model_class_name, model_asset_path, &err_);
+ result.files.push_back(GenerateWrapperFile(module_root_, model_info, &err_));
+ result.files.push_back(GenerateBuildGradle(module_root_, model_info, &err_));
+ result.files.push_back(
+ GenerateAndroidManifest(module_root_, model_info, &err_));
+ result.files.push_back(GenerateDoc(module_root_, model_info, &err_));
+ return result;
+}
+
+GenerationResult AndroidJavaGenerator::Generate(
+ const char* model_storage, const std::string& package_name,
+ const std::string& model_class_name, const std::string& model_asset_path) {
+ const Model* model = GetModel(model_storage);
+ return Generate(model, package_name, model_class_name, model_asset_path);
+}
+
+std::string AndroidJavaGenerator::GetErrorMessage() {
+ return err_.GetMessage();
+}
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/codegen/android_java_generator.h b/tensorflow_lite_support/codegen/android_java_generator.h
new file mode 100644
index 00000000..634ccf69
--- /dev/null
+++ b/tensorflow_lite_support/codegen/android_java_generator.h
@@ -0,0 +1,116 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
+#define TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow_lite_support/codegen/code_generator.h"
+#include "tensorflow_lite_support/codegen/utils.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+namespace details_android_java {
+
+/// The intermediate data structure for generating code from TensorMetadata.
+/// Should only be used as a const reference once created.
+struct TensorInfo {
+ std::string name;
+ std::string upper_camel_name;
+ std::string content_type;
+ std::string wrapper_type;
+ std::string processor_type;
+ bool is_input;
+ /// Optional. Set to -1 if not applicable.
+ int normalization_unit;
+ /// Optional. Set to -1 if associated_axis_label is empty.
+ int associated_axis_label_index;
+ /// Optional. Set to -1 if associated_value_label is empty.
+ int associated_value_label_index;
+};
+
+/// The intermediate data structure for generating code from ModelMetadata.
+/// Should only be used as a const reference once created.
+struct ModelInfo {
+ std::string package_name;
+ std::string model_asset_path;
+ std::string model_class_name;
+ std::string model_versioned_name;
+ std::vector<TensorInfo> inputs;
+ std::vector<TensorInfo> outputs;
+ // Extra helper fields. For models with inputs "a", "b" and outputs "x", "y":
+ std::string input_type_param_list;
+ // e.g. "TensorImage a, TensorBuffer b"
+ std::string inputs_list;
+ // e.g. "a, b"
+ std::string postprocessor_type_param_list;
+ // e.g. "ImageProcessor xPostprocessor, TensorProcessor yPostprocessor"
+ std::string postprocessors_list;
+ // e.g. "xPostprocessor, yPostprocessor"
+};
+
+} // namespace details_android_java
+
+constexpr char JAVA_EXT[] = ".java";
+
+/// Generates Android supporting code and modules (in Java) based on TFLite
+/// metadata.
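+///
+/// Illustrative usage (paths and names below are hypothetical):
+///   AndroidJavaGenerator generator("/path/to/destination/module");
+///   GenerationResult result = generator.Generate(
+///       model_buffer, "com.example.app", "MyModel", "my_model.tflite");
+///   // Each result.files entry carries a destination path and file content.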
+class AndroidJavaGenerator : public CodeGenerator {
+ public:
+ /// Creates an AndroidJavaGenerator.
+ /// Args:
+  ///   - module_root: The root of the destination Java module.
+ explicit AndroidJavaGenerator(const std::string& module_root);
+
+ /// Generates files. Returns the file paths and contents.
+ /// Args:
+ /// - model: The TFLite model with Metadata filled.
+  ///   - package_name: The name of the Java package that the generated
+  ///   classes belong to.
+ /// - model_class_name: A readable name of the generated wrapper class, such
+ /// as "ImageClassifier", "MobileNetV2" or "MyModel".
+  ///   - model_asset_path: The relative path to the model file in the assets.
+ // TODO(b/141225157): Automatically generate model_class_name.
+ GenerationResult Generate(const Model* model, const std::string& package_name,
+ const std::string& model_class_name,
+ const std::string& model_asset_path);
+
+ /// Generates files and returns the file paths and contents.
+  /// It's mostly identical to the overload above, but the model here is
+  /// provided as raw binary flatbuffer content that has not been parsed yet.
+ GenerationResult Generate(const char* model_storage,
+ const std::string& package_name,
+ const std::string& model_class_name,
+ const std::string& model_asset_path);
+
+ std::string GetErrorMessage();
+
+ private:
+ const std::string module_root_;
+ ErrorReporter err_;
+};
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
diff --git a/tensorflow_lite_support/codegen/code_generator.cc b/tensorflow_lite_support/codegen/code_generator.cc
new file mode 100644
index 00000000..1337708d
--- /dev/null
+++ b/tensorflow_lite_support/codegen/code_generator.cc
@@ -0,0 +1,179 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/codegen/code_generator.h"
+
+#include <cctype>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "tensorflow_lite_support/codegen/utils.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+namespace {
+
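+// Renames duplicated entries in `names` in place by appending 1-based
+// indices. For example (illustrative): {"image", "image", "audio"} becomes
+// {"image1", "image2", "audio"}.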
+void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) {
+ auto& names = *names_ptr;
+ std::unordered_map<std::string, int> indexes;
+ std::unordered_map<std::string, int> first_appearance;
+ for (int i = 0; i < names.size(); i++) {
+ if (indexes.find(names[i]) == indexes.end()) {
+ indexes[names[i]] = 1;
+ first_appearance[names[i]] = i;
+ } else {
+ indexes[names[i]] += 1;
+ names[i].append(std::to_string(indexes[names[i]]));
+ }
+ }
+ for (const auto& it : first_appearance) {
+ const auto& name = it.first;
+ const auto i = it.second;
+ if (indexes[name] > 1) {
+ names[i].append("1");
+ }
+ }
+}
+
+} // namespace
+
+CodeGenerator::CodeGenerator() {}
+
+bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata,
+ ErrorReporter* err) {
+ if (metadata == nullptr) {
+ err->Error("Loading nullptr is not allowed");
+ return false;
+ }
+  if (metadata->subgraph_metadata() == nullptr ||
+      metadata->subgraph_metadata()->size() != 1) {
+    err->Error("Exactly one subgraph is supported");
+ return false;
+ }
+ return true;
+}
+
+std::pair<std::vector<std::string>, std::vector<std::string>>
+CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs,
+ const TensorMetadataList* outputs) {
+ std::vector<std::string> input_names;
+ std::vector<std::string> output_names;
+ if (inputs != nullptr) {
+ input_names.reserve(inputs->size());
+ for (const auto* tensor : *inputs) {
+ input_names.push_back(NameTensor(*tensor, "input"));
+ }
+ }
+ if (outputs != nullptr) {
+ output_names.reserve(outputs->size());
+ for (const auto* tensor : *outputs) {
+ output_names.push_back(NameTensor(*tensor, "output"));
+ }
+ }
+  // Resolve name conflicts.
+ ResolveConflictedInputAndOutputNames(&input_names, &output_names);
+ return std::make_pair(input_names, output_names);
+}
+
+std::string CodeGenerator::ConvertToValidName(const std::string& name) {
+ // lowercase all
+ std::string result = name;
+ for (int i = 0; i < result.size(); i++) {
+    result[i] = std::tolower(static_cast<unsigned char>(result[i]));
+ }
+ // replace all non-alpha or non-numeric with underscores, except underscore
+ // itself
+ for (int i = 0; i < result.size(); i++) {
+    if (result[i] != '_' &&
+        !std::isalnum(static_cast<unsigned char>(result[i]))) {
+ result[i] = '_';
+ }
+ }
+ // remove leading underscores
+ int leading_underscores = 0;
+ while (leading_underscores < result.size() &&
+ result[leading_underscores] == '_') {
+ leading_underscores++;
+ }
+ result.erase(0, leading_underscores);
+ if (result.empty()) {
+ return "";
+ }
+ // first char should be alpha
+  if (std::isalpha(static_cast<unsigned char>(result[0]))) {
+ return result;
+ }
+ return "tensor_" + result;
+}
+
+std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
+ const std::string& default_name) {
+ if (tensor.name() != nullptr && tensor.name()->size() > 0) {
+ // TODO(b/141225157) Validate tensor name. It should be in lower case.
+ auto suggested_name = ConvertToValidName(tensor.name()->str());
+ if (!suggested_name.empty()) {
+ return suggested_name;
+ }
+ }
+ auto* content = tensor.content();
+ if (content == nullptr || content->content_properties() == nullptr) {
+ return default_name;
+ }
+ switch (content->content_properties_type()) {
+ case ContentProperties_ImageProperties:
+ return "image";
+ case ContentProperties_FeatureProperties:
+ return "feature";
+ default:
+ return default_name;
+ }
+}
+
+void CodeGenerator::ResolveConflictedInputAndOutputNames(
+ std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
+ std::unordered_set<std::string> io_conflict;
+ auto& input_names = *inputs;
+ auto& output_names = *outputs;
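+  // 1. First, find names that appear in both inputs and outputs, and prefix
+  // them with "input_" / "output_" to disambiguate.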
+ for (const auto& input : input_names) {
+ if (io_conflict.find(input) != io_conflict.end()) {
+ continue;
+ }
+ for (const auto& output : output_names) {
+ if (input == output) {
+ io_conflict.insert(input);
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < input_names.size(); i++) {
+ if (io_conflict.find(input_names[i]) != io_conflict.end()) {
+ input_names[i] = "input_" + input_names[i];
+ }
+ }
+ for (int i = 0; i < output_names.size(); i++) {
+ if (io_conflict.find(output_names[i]) != io_conflict.end()) {
+ output_names[i] = "output_" + output_names[i];
+ }
+ }
+  // 2. Second, add indices to names duplicated within inputs or within
+  // outputs.
+ ResolveConflictedNamesByAddingIndex(&input_names);
+ ResolveConflictedNamesByAddingIndex(&output_names);
+}
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/codegen/code_generator.h b/tensorflow_lite_support/codegen/code_generator.h
new file mode 100644
index 00000000..b557773d
--- /dev/null
+++ b/tensorflow_lite_support/codegen/code_generator.h
@@ -0,0 +1,80 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_CODE_GENERATOR_H_
+#define TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_CODE_GENERATOR_H_
+
+#include <map>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "tensorflow_lite_support/codegen/utils.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+struct GenerationResult {
+ struct File {
+ std::string path;
+ std::string content;
+ };
+ std::vector<File> files;
+};
+
+/// Defines language-independent codegen strategies, like class naming, etc.
+/// Should not be used directly.
+class CodeGenerator {
+ public:
+ CodeGenerator();
+
+ using TensorMetadataList =
+ typename flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>;
+
+ virtual ~CodeGenerator() {}
+
+ // Strategies.
+ /// Names all the IO tensors. It's useful when they don't have names, or the
+ /// names have conflicts. We have to name every tensor for code generation.
+ // TODO(b/141225157): Add reserved keywords check.
+ static std::pair<std::vector<std::string>, std::vector<std::string>>
+ NameInputsAndOutputs(const TensorMetadataList* inputs,
+ const TensorMetadataList* outputs);
+
+  /// Verifies that a metadata instance is usable for code generation.
+  /// Returns false if the metadata is not suitable for generation.
+ static bool VerifyMetadata(const ModelMetadata* metadata, ErrorReporter* err);
+
+ protected:
+  /// Converts a name into a valid form. Rules:
+  /// - lowercase all letters.
+  /// - replace all characters that are neither alphabetic nor numeric with
+  ///   underscores, except underscore itself.
+  /// - remove leading underscores.
+  /// - add a prefix if the leading character is a number.
+  /// Returns an empty string if conversion is not possible.
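+  /// Example: ConvertToValidName("3000 Cool Tensors") returns
+  /// "tensor_3000_cool_tensors".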
+ static std::string ConvertToValidName(const std::string& name);
+ static std::string NameTensor(const TensorMetadata& tensor,
+ const std::string& default_name);
+ static void ResolveConflictedInputAndOutputNames(
+ std::vector<std::string>* input, std::vector<std::string>* output);
+};
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_CODE_GENERATOR_H_
diff --git a/tensorflow_lite_support/codegen/code_generator_test.cc b/tensorflow_lite_support/codegen/code_generator_test.cc
new file mode 100644
index 00000000..5e9d64a0
--- /dev/null
+++ b/tensorflow_lite_support/codegen/code_generator_test.cc
@@ -0,0 +1,126 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/codegen/code_generator.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace support {
+namespace codegen {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class CodeGeneratorTest : public ::testing::Test {
+ public:
+ class TestingCodeGenerator : public CodeGenerator {
+ public:
+ explicit TestingCodeGenerator() : CodeGenerator() {}
+
+ // Make tested method public.
+ static std::string ConvertToValidName(const std::string& name) {
+ return CodeGenerator::ConvertToValidName(name);
+ }
+ static void ResolveConflictedInputAndOutputNames(
+ std::vector<std::string>* input, std::vector<std::string>* output) {
+ CodeGenerator::ResolveConflictedInputAndOutputNames(input, output);
+ }
+ };
+};
+
+TEST_F(CodeGeneratorTest, UpperCasesShouldLower) {
+ EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("AlphaBetCOOL"),
+ "alphabetcool");
+}
+
+TEST_F(CodeGeneratorTest, NonAlphaNumShouldReplace) {
+ EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("A+=B C\t"), "a__b_c_");
+}
+
+TEST_F(CodeGeneratorTest, NoLeadingUnderscore) {
+ EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("+KAI Z"), "kai_z");
+}
+
+TEST_F(CodeGeneratorTest, NoLeadingNumbers) {
+ EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("3000 Cool Tensors"),
+ "tensor_3000_cool_tensors");
+}
+
+TEST_F(CodeGeneratorTest, TestSimpleIONames) {
+ std::vector<std::string> inputs = {"image"};
+ std::vector<std::string> outputs = {"output"};
+ TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
+ EXPECT_THAT(inputs, ElementsAreArray({"image"}));
+ EXPECT_THAT(outputs, ElementsAreArray({"output"}));
+}
+
+TEST_F(CodeGeneratorTest, TestIOConflict) {
+ std::vector<std::string> inputs = {"image"};
+ std::vector<std::string> outputs = {"image"};
+ TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
+ EXPECT_THAT(inputs, ElementsAreArray({"input_image"}));
+ EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
+}
+
+TEST_F(CodeGeneratorTest, TestInternalConflict) {
+ std::vector<std::string> inputs = {"image", "image"};
+ std::vector<std::string> outputs = {"output"};
+ TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
+ EXPECT_THAT(inputs, ElementsAreArray({"image1", "image2"}));
+ EXPECT_THAT(outputs, ElementsAreArray({"output"}));
+}
+
+TEST_F(CodeGeneratorTest, TestAllConflictNTo1) {
+ std::vector<std::string> inputs = {"image", "image"};
+ std::vector<std::string> outputs = {"image"};
+ TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
+ EXPECT_THAT(inputs, ElementsAreArray({"input_image1", "input_image2"}));
+ EXPECT_THAT(outputs, ElementsAreArray({"output_image"}));
+}
+
+TEST_F(CodeGeneratorTest, TestAllConflict) {
+ std::vector<std::string> inputs = {"image", "audio", "image", "audio",
+ "audio"};
+ std::vector<std::string> outputs = {"image", "image", "audio", "feature",
+ "feature"};
+ TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
+ EXPECT_THAT(inputs,
+ ElementsAreArray({"input_image1", "input_audio1", "input_image2",
+ "input_audio2", "input_audio3"}));
+ EXPECT_THAT(outputs,
+ ElementsAreArray({"output_image1", "output_image2",
+ "output_audio", "feature1", "feature2"}));
+}
+
+TEST_F(CodeGeneratorTest, TestAllConflictReversed) {
+ std::vector<std::string> inputs = {"image", "image", "audio", "feature",
+ "feature"};
+ std::vector<std::string> outputs = {"image", "audio", "image", "audio",
+ "audio"};
+ TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs);
+ EXPECT_THAT(inputs,
+ ElementsAreArray({"input_image1", "input_image2", "input_audio",
+ "feature1", "feature2"}));
+ EXPECT_THAT(outputs, ElementsAreArray({"output_image1", "output_audio1",
+ "output_image2", "output_audio2",
+ "output_audio3"}));
+}
+
+} // namespace
+} // namespace codegen
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/codegen/metadata_helper.cc b/tensorflow_lite_support/codegen/metadata_helper.cc
new file mode 100644
index 00000000..00c97236
--- /dev/null
+++ b/tensorflow_lite_support/codegen/metadata_helper.cc
@@ -0,0 +1,100 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/codegen/metadata_helper.h"
+
+#include "tensorflow_lite_support/codegen/utils.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+constexpr char BUFFER_KEY[] = "TFLITE_METADATA";
+const ModelMetadata* GetMetadataFromModel(const Model* model) {
+ if (model == nullptr || model->metadata() == nullptr) {
+ return nullptr;
+ }
+ for (auto i = 0; i < model->metadata()->size(); i++) {
+ const auto* name = model->metadata()->Get(i)->name();
+ if (name != nullptr && name->str() == BUFFER_KEY) {
+ const auto buffer_index = model->metadata()->Get(i)->buffer();
+ if (model->buffers() == nullptr ||
+ model->buffers()->size() <= buffer_index) {
+ continue;
+ }
+ const auto* buffer_vec = model->buffers()->Get(buffer_index)->data();
+ if (buffer_vec == nullptr || buffer_vec->data() == nullptr) {
+ continue;
+ }
+ return GetModelMetadata(buffer_vec->data());
+ }
+ }
+ return nullptr;
+}
+
+int FindAssociatedFile(const TensorMetadata* metadata,
+ const AssociatedFileType file_type,
+ const std::string& tensor_identifier,
+ ErrorReporter* err) {
+ int result = -1;
+ if (metadata->associated_files() == nullptr ||
+ metadata->associated_files()->size() == 0) {
+ return result;
+ }
+ for (int i = 0; i < metadata->associated_files()->size(); i++) {
+ const auto* file_metadata = metadata->associated_files()->Get(i);
+ if (file_metadata->type() == file_type) {
+ if (result >= 0) {
+ err->Warning(
+ "Multiple associated file of type %d found on tensor %s. Only the "
+ "first one will be used.",
+ file_type, tensor_identifier.c_str());
+ continue;
+ }
+ result = i;
+ }
+ }
+ return result;
+}
+
+int FindNormalizationUnit(const TensorMetadata* metadata,
+ const std::string& tensor_identifier,
+ ErrorReporter* err) {
+ int result = -1;
+ if (metadata->process_units() == nullptr ||
+ metadata->process_units()->size() == 0) {
+ return result;
+ }
+ for (int i = 0; i < metadata->process_units()->size(); i++) {
+    const auto* process_unit = metadata->process_units()->Get(i);
+    if (process_unit->options_type() ==
+ ProcessUnitOptions_NormalizationOptions) {
+ if (result >= 0) {
+ err->Warning(
+ "Multiple normalization unit found in tensor %s. Only the first "
+ "one will be effective.",
+ tensor_identifier.c_str());
+ continue;
+ }
+ result = i;
+ }
+ }
+ return result;
+}
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/codegen/metadata_helper.h b/tensorflow_lite_support/codegen/metadata_helper.h
new file mode 100644
index 00000000..8e3dc6ab
--- /dev/null
+++ b/tensorflow_lite_support/codegen/metadata_helper.h
@@ -0,0 +1,51 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_METADATA_HELPER_H_
+#define TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_METADATA_HELPER_H_
+
+#include <string>
+
+#include "tensorflow_lite_support/codegen/utils.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+/// Parses a ModelMetadata out from a Model. The returned ModelMetadata's
+/// lifetime is scoped by the model.
+/// Returns nullptr if we cannot find any metadata.
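+///
+/// Illustrative usage:
+///   const ModelMetadata* metadata = GetMetadataFromModel(model);
+///   if (metadata == nullptr) { /* no "TFLITE_METADATA" buffer in model */ }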
+const ModelMetadata* GetMetadataFromModel(const Model* model);
+
+/// Finds an associated file of a certain type from a TensorMetadata. If
+/// multiple files meet the criteria, only the first one is used. If no file
+/// meets the criteria, -1 is returned.
+int FindAssociatedFile(const TensorMetadata* metadata,
+ const AssociatedFileType file_type,
+ const std::string& tensor_identifier,
+ ErrorReporter* err);
+
+/// Finds the first normalization unit. Returns -1 if there is none.
+int FindNormalizationUnit(const TensorMetadata* metadata,
+ const std::string& tensor_identifier,
+ ErrorReporter* err);
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_METADATA_HELPER_H_
diff --git a/tensorflow_lite_support/codegen/python/BUILD b/tensorflow_lite_support/codegen/python/BUILD
new file mode 100644
index 00000000..ee4dcbbd
--- /dev/null
+++ b/tensorflow_lite_support/codegen/python/BUILD
@@ -0,0 +1,37 @@
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+pybind_extension(
+ name = "_pywrap_codegen",
+ srcs = [
+ "codegen_lib.cc",
+ ],
+ features = ["-use_header_modules"],
+ module_name = "_pywrap_codegen",
+ deps = [
+ "//tensorflow_lite_support/codegen:android_java_generator",
+ "//tensorflow_lite_support/codegen:code_generator",
+ "@local_config_python//:python_headers",
+ "@pybind11",
+ ],
+)
+
+py_binary(
+ name = "codegen",
+ srcs = [
+ "codegen.py",
+ ],
+ python_version = "PY3",
+ deps = [
+ ":_pywrap_codegen",
+ "@absl_py//absl:app",
+ "@absl_py//absl/flags",
+ "@absl_py//absl/logging",
+ ],
+)
diff --git a/tensorflow_lite_support/codegen/python/codegen.py b/tensorflow_lite_support/codegen/python/codegen.py
new file mode 100644
index 00000000..7309a69d
--- /dev/null
+++ b/tensorflow_lite_support/codegen/python/codegen.py
@@ -0,0 +1,104 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Generates Android Java sources from a TFLite model with metadata."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import sys
+from absl import app
+from absl import flags
+from absl import logging
+
+from tensorflow_lite_support.codegen.python import _pywrap_codegen
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.')
+flags.DEFINE_string('destination', None, 'Path of destination of generation.')
+flags.DEFINE_string('package_name', 'org.tensorflow.lite.support',
+ 'Name of generated java package to put the wrapper class.')
+flags.DEFINE_string(
+ 'model_class_name', 'MyModel',
+ 'Name of generated wrapper class (should not contain package name).')
+flags.DEFINE_string(
+ 'model_asset_path', '',
+ '(Optional) Path to the model in generated assets/ dir. If not set, '
+    'the generator will use the base name of the input model.'
+)
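+
+# Illustrative invocation (paths are hypothetical):
+#   bazel run //tensorflow_lite_support/codegen/python:codegen -- \
+#     --model=/tmp/my_model.tflite --destination=/tmp/codegen_output \
+#     --package_name=com.example.app --model_class_name=MyModel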
+
+
+def get_model_buffer(path):
+  if not os.path.isfile(path):
+    logging.error('Cannot find model at path %s.', path)
+    sys.exit(1)
+  with open(path, 'rb') as f:
+    return f.read()
+
+
+def prepare_directory_for_file(file_path):
+ target_dir = os.path.dirname(file_path)
+ if not os.path.exists(target_dir):
+ os.makedirs(target_dir)
+ return
+ if not os.path.isdir(target_dir):
+ logging.error('Cannot write to %s', target_dir)
+
+
+def run_main(argv):
+ """Main function of the codegen."""
+
+ if len(argv) > 1:
+    logging.error('Non-flag arguments found: [%s]', ', '.join(argv[1:]))
+
+ codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination)
+ model_buffer = get_model_buffer(FLAGS.model)
+ model_asset_path = FLAGS.model_asset_path
+ if not model_asset_path:
+ model_asset_path = os.path.basename(FLAGS.model)
+ result = codegen.generate(model_buffer, FLAGS.package_name,
+ FLAGS.model_class_name, model_asset_path)
+ error_message = codegen.get_error_message().strip()
+ if error_message:
+ logging.error(error_message)
+ if not result.files:
+ logging.error('Generation failed!')
+ return
+
+ for each in result.files:
+ prepare_directory_for_file(each.path)
+ with open(each.path, 'w') as f:
+ f.write(each.content)
+
+ logging.info('Generation succeeded!')
+ model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets',
+ model_asset_path)
+ prepare_directory_for_file(model_asset_path)
+ shutil.copy(FLAGS.model, model_asset_path)
+ logging.info('Model copied into assets!')
+
+
+# Simple wrapper to make the code pip-friendly
+def main():
+ flags.mark_flag_as_required('model')
+ flags.mark_flag_as_required('destination')
+ app.run(main=run_main, argv=sys.argv)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/tensorflow_lite_support/codegen/python/codegen_lib.cc b/tensorflow_lite_support/codegen/python/codegen_lib.cc
new file mode 100644
index 00000000..6b2cd5ea
--- /dev/null
+++ b/tensorflow_lite_support/codegen/python/codegen_lib.cc
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/detail/common.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+#include "tensorflow_lite_support/codegen/android_java_generator.h"
+#include "tensorflow_lite_support/codegen/code_generator.h"
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+template <typename... Args>
+using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
+
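+// Illustrative Python-side usage of this module:
+//   from tensorflow_lite_support.codegen.python import _pywrap_codegen
+//   codegen = _pywrap_codegen.AndroidJavaGenerator(destination)
+//   result = codegen.generate(model_buffer, package_name, class_name, asset)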
+PYBIND11_MODULE(_pywrap_codegen, m) {
+ pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator")
+ .def(pybind11::init<const std::string &>())
+ .def("generate",
+ overload_cast_<const char *, const std::string &,
+ const std::string &, const std::string &>()(
+ &AndroidJavaGenerator::Generate))
+ .def("get_error_message", &AndroidJavaGenerator::GetErrorMessage);
+ pybind11::class_<GenerationResult>(m, "GenerationResult")
+ .def(pybind11::init<>())
+ .def_readwrite("files", &GenerationResult::files);
+ pybind11::class_<GenerationResult::File>(m, "GenerationResultFile")
+ .def(pybind11::init<>())
+ .def_readwrite("path", &GenerationResult::File::path)
+ .def_readwrite("content", &GenerationResult::File::content);
+}
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/codegen/utils.cc b/tensorflow_lite_support/codegen/utils.cc
new file mode 100644
index 00000000..c75fc5fa
--- /dev/null
+++ b/tensorflow_lite_support/codegen/utils.cc
@@ -0,0 +1,194 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/codegen/utils.h"
+
+#include <cctype>
+#include <cstdarg>
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+int ErrorReporter::Warning(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ return Report("[WARN] ", format, args);
+}
+
+int ErrorReporter::Error(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ return Report("[ERROR] ", format, args);
+}
+
+int ErrorReporter::Report(const char* prefix, const char* format,
+ va_list args) {
+ char buf[1024];
+ int formatted = vsnprintf(buf, sizeof(buf), format, args);
+ buffer_ << prefix << buf << std::endl;
+ return formatted;
+}
+
+std::string ErrorReporter::GetMessage() {
+ std::string value = buffer_.str();
+ buffer_.str("");
+ return value;
+}
+
+CodeWriter::CodeWriter(ErrorReporter* err) : indent_(0), err_(err) {}
+
+void CodeWriter::SetTokenValue(const std::string& token,
+ const std::string& value) {
+ value_map_[token] = value;
+}
+
+const std::string CodeWriter::GetTokenValue(const std::string& token) const {
+ auto iter = value_map_.find(token);
+ if (iter == value_map_.end()) {
+    // Typically only code generators call this function (or `Append`). It's
+    // their duty to make sure the token is valid, and requesting an invalid
+    // token indicates a flaw in the code generation logic.
+ err_->Error("Internal: Cannot find value with token '%s'", token.c_str());
+ return "";
+ }
+ return iter->second;
+}
+
+void CodeWriter::SetIndentString(const std::string& indent_str) {
+ indent_str_ = indent_str;
+}
+
+void CodeWriter::Indent() { indent_++; }
+
+void CodeWriter::Outdent() { indent_--; }
+
+std::string CodeWriter::GenerateIndent() const {
+ std::string res;
+ res.reserve(indent_str_.size() * indent_);
+ for (int i = 0; i < indent_; i++) {
+ res.append(indent_str_);
+ }
+ return res;
+}
+
+void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); }
+
+void CodeWriter::AppendNoNewLine(const std::string& text) {
+ AppendInternal(text, false);
+}
+
+void CodeWriter::AppendInternal(const std::string& text, bool newline) {
+  // Prefix the indent when the buffer is at the start of a new line and the
+  // text to write begins with visible content.
+  if ((buffer_.empty()             // nothing in the buffer
+       || buffer_.back() == '\n')  // or we are on a new line
+      && (!text.empty() && text[0] != '\n' && text[0] != '\r')) {
+ buffer_.append(GenerateIndent());
+ }
+ // State machine variables
+ bool in_token = false;
+ int i = 0;
+  // Roughly reserve memory up front.
+ buffer_.reserve(buffer_.size() + text.size());
+ std::string token_buffer;
+  // A simple LL(1)-style scan over the text.
+ while (i < text.size()) {
+ char cur = text[i];
+    char cur_next = i == text.size() - 1 ? '\0' : text[i + 1];  // Sentinel.
+ if (!in_token) {
+ if (cur == '{' && cur_next == '{') { // Enter token
+ in_token = true;
+ i += 2;
+ } else if (cur == '\n') { // We need to apply global indent here
+ buffer_.push_back(cur);
+ if (cur_next != '\0' && cur_next != '\n' && cur_next != '\r') {
+ buffer_.append(GenerateIndent());
+ }
+ i += 1;
+ } else {
+ buffer_.push_back(cur);
+ i += 1;
+ }
+ } else {
+ if (cur == '}' && cur_next == '}') { // Close token
+ in_token = false;
+ const auto value = GetTokenValue(token_buffer);
+ buffer_.append(value);
+ token_buffer.clear();
+ i += 2;
+ } else {
+ token_buffer.push_back(cur);
+ i += 1;
+ }
+ }
+ }
+ if (!token_buffer.empty()) {
+    // Typically only code generators call this function. It's their duty to
+    // make sure the code (or template) has valid syntax, and an unclosed
+    // "{{...}}" indicates a severe error in the template.
+ err_->Error("Internal: Invalid template: {{token}} is not closed.");
+ }
+ if (newline) {
+ buffer_.push_back('\n');
+ }
+}
+
+void CodeWriter::NewLine() { Append(""); }
+
+void CodeWriter::Backspace(int n) {
+ buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0);
+}
+
+std::string CodeWriter::ToString() const { return buffer_; }
+
+bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); }
+
+void CodeWriter::Clear() {
+ buffer_.clear();
+ value_map_.clear();
+ indent_ = 0;
+}
+
+std::string SnakeCaseToCamelCase(const std::string& s) {
+ std::string t;
+ t.reserve(s.length());
+ size_t i = 0;
+  // Note: builds the result with simple string += operations.
+ bool cap = false;
+ while (i < s.size()) {
+ const char c = s[i++];
+ if (c == '_') {
+ cap = true;
+ } else if (cap) {
+      t += std::toupper(static_cast<unsigned char>(c));
+ cap = false;
+ } else {
+ t += c;
+ }
+ }
+ return t;
+}
+
+std::string JoinPath(const std::string& a, const std::string& b) {
+ if (a.empty()) return b;
+ std::string a_fixed = a;
+ if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back();
+ std::string b_fixed = b;
+ if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1);
+ return a_fixed + "/" + b_fixed;
+}
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/codegen/utils.h b/tensorflow_lite_support/codegen/utils.h
new file mode 100644
index 00000000..98768b6a
--- /dev/null
+++ b/tensorflow_lite_support/codegen/utils.h
@@ -0,0 +1,127 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_UTILS_H_
+#define TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_UTILS_H_
+
+#include <map>
+#include <sstream>
+#include <string>
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+/// Collects runtime error logs which can be shown later.
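+/// Illustrative usage:
+///   ErrorReporter err;
+///   err.Warning("Tensor %s has no name.", "input_1");
+///   std::string log = err.GetMessage();
+///   // log == "[WARN] Tensor input_1 has no name.\n"; the buffer is cleared.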
+// TODO(b/150538286): Consider a better mechanism to simplify callsite code.
+class ErrorReporter {
+ public:
+ int Warning(const char* format, ...);
+ int Error(const char* format, ...);
+ std::string GetMessage();
+
+ private:
+ int Report(const char* prefix, const char* format, va_list args);
+ std::stringstream buffer_;
+};
+
+/// Implements basic code generation with text templates.
+///
+/// It accepts code templates and concatenates them into complete code. A
+/// template may contain named tokens.
+///
+/// Example code:
+///   ErrorReporter err;
+///   CodeWriter code(&err);
+///   code.SetTokenValue("NAME", "Foo");
+///   code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
+///   code.SetTokenValue("NAME", "Bar");
+///   code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }");
+///
+/// Output:
+/// void Foo() { printf("%s", "Foo"); }
+/// void Bar() { printf("%s", "Bar"); }
+class CodeWriter {
+ public:
+ explicit CodeWriter(ErrorReporter* err);
+  /// Sets the value of a token. When generating code with a template, a
+  /// string wrapped in a pair of {{ and }} is regarded as a token and is
+  /// replaced with the corresponding value during code generation.
+  /// It overwrites the value if the token already has one.
+ void SetTokenValue(const std::string& token, const std::string& value);
+
+ /// Gets the current value set on the given token.
+ const std::string GetTokenValue(const std::string& token) const;
+
+ /// Sets the unit indent string. For example, in Java it should be " ".
+ void SetIndentString(const std::string& indent);
+
+ /// Increases the indent by a unit (the string set in SetIndentString).
+ void Indent();
+
+ /// Decreases the indent by a unit (the string set in SetIndentString).
+ void Outdent();
+
+ /// Generates the indentation string.
+ std::string GenerateIndent() const;
+
+  /// Appends a piece of template code to the stream. Every named token will
+  /// be replaced with its real value. A new line will always be appended at
+  /// the end.
+ void Append(const std::string& text);
+
+  /// Appends a piece of template code to the stream. Same as `Append`, but a
+  /// new line will not be appended at the end.
+ void AppendNoNewLine(const std::string& text);
+
+ /// Appends a new line to the stream.
+ void NewLine();
+
+  /// Deletes the last N characters in the stream. If the stream has fewer
+  /// than N characters, deletes all of them.
+ void Backspace(int n);
+
+ std::string ToString() const;
+
+  /// Checks if the internal string stream is empty. Note: This method has
+  /// overhead.
+ bool IsStreamEmpty() const;
+
+ /// Clears all the internal string stream and value map.
+ void Clear();
+
+ private:
+ void AppendInternal(const std::string& text, bool newline);
+
+ std::string indent_str_;
+ int indent_;
+
+ std::map<std::string, std::string> value_map_;
+ std::string buffer_;
+
+ ErrorReporter* err_;
+};
+
+/// Converts foo_bar_name to fooBarName. It's the caller's duty to make sure
+/// the given string "s" is already in snake case; otherwise unexpected
+/// behavior may occur.
+std::string SnakeCaseToCamelCase(const std::string& s);
+
+/// Joins two parts of a file path into one, connected by the Unix path
+/// separator '/'. It's the caller's duty to ensure the two parts are valid.
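+/// Example (illustrative): JoinPath("src/main", "/assets") returns
+/// "src/main/assets".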
+std::string JoinPath(const std::string& a, const std::string& b);
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite
+
+#endif // TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_UTILS_H_
diff --git a/tensorflow_lite_support/codegen/utils_test.cc b/tensorflow_lite_support/codegen/utils_test.cc
new file mode 100644
index 00000000..3111f37b
--- /dev/null
+++ b/tensorflow_lite_support/codegen/utils_test.cc
@@ -0,0 +1,97 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/codegen/utils.h"
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace support {
+namespace codegen {
+namespace {
+
+TEST(ErrorReporterTest, TestReportError) {
+ ErrorReporter err;
+ err.Error("some text");
+ EXPECT_EQ(err.GetMessage(), "[ERROR] some text\n");
+ EXPECT_EQ(err.GetMessage(), "");
+}
+
+TEST(CodeGeneratorTest, TestExample) {
+ ErrorReporter err;
+ CodeWriter writer(&err);
+ writer.SetTokenValue("NAME", "Foo");
+ const std::string text = R"(void {{NAME}}() { printf("%s", "{{NAME}}"); })";
+ writer.Append(text);
+ writer.SetTokenValue("NAME", "Bar");
+ writer.Append(text);
+ EXPECT_EQ(
+ "void Foo() { printf(\"%s\", \"Foo\"); }\n"
+ "void Bar() { printf(\"%s\", \"Bar\"); }\n",
+ writer.ToString());
+}
+
+TEST(CodeGeneratorTest, TestInexistentToken) {
+ ErrorReporter err;
+ CodeWriter writer(&err);
+ writer.SetTokenValue("NAME", "Foo");
+ const std::string text = R"(void {{name}}() {})";
+ writer.Append(text);
+ EXPECT_EQ(err.GetMessage(),
+ "[ERROR] Internal: Cannot find value with token 'name'\n");
+}
+
+TEST(CodeGeneratorTest, TestUnclosedToken) {
+ ErrorReporter err;
+ CodeWriter writer(&err);
+ writer.SetTokenValue("NAME", "Foo");
+ const std::string text = R"(void {{NAME}() {})";
+ writer.Append(text);
+ EXPECT_EQ(err.GetMessage(),
+ "[ERROR] Internal: Invalid template: {{token}} is not closed.\n");
+}
+
+TEST(CodeGeneratorTest, TestIndentControl) {
+ ErrorReporter err;
+ CodeWriter writer(&err);
+ writer.SetIndentString(" ");
+ writer.Indent();
+ writer.AppendNoNewLine("abcde"); // Will indent
+ EXPECT_EQ(" abcde", writer.ToString());
+ writer.Clear();
+ writer.Indent();
+ writer.AppendNoNewLine("abc\n\nde");
+ // The blank line will not indent
+ EXPECT_EQ(" abc\n\n de", writer.ToString());
+ writer.Clear();
+ writer.Indent();
+ writer.Append("abc");
+ writer.Outdent();
+ writer.AppendNoNewLine("def");
+ EXPECT_EQ(" abc\ndef", writer.ToString());
+}
+
+TEST(CaseConversionTest, TestSnakeToCamel) {
+ EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel"));
+ EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel_"));
+ EXPECT_EQ("ImACamel", SnakeCaseToCamelCase("_im_a_camel"));
+ EXPECT_EQ("", SnakeCaseToCamelCase("_"));
+ EXPECT_EQ("camel", SnakeCaseToCamelCase("camel"));
+}
+
+} // namespace
+} // namespace codegen
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/BUILD b/tensorflow_lite_support/custom_ops/BUILD
new file mode 100644
index 00000000..35734cf2
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/BUILD
@@ -0,0 +1,43 @@
+load("@org_tensorflow//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# This will generate the tf_text_flex_delegate cc_library, which is a custom
+# flex delegate that only contains ops in listed models.
+tflite_flex_cc_library(
+ name = "tf_text_flex_delegate",
+ additional_deps = ["@org_tensorflow_text//tensorflow_text:ops_lib"],
+ models = [
+ # TODO(b/160817619) Replace with a more complex model.
+ "testdata/sentencepiece_tokenizer_flex_op.tflite",
+ ],
+)
+
+# bazel test --config=monolithic tensorflow_lite_support/custom_ops:tflite_inference_test
+cc_test(
+ name = "tflite_inference_test",
+ srcs = ["tflite_inference_main.cc"],
+ args = ["--model=tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite"],
+ data = ["//tensorflow_lite_support/custom_ops:testdata/sentencepiece_tokenizer_flex_op.tflite"],
+ deps = [
+ ":tf_text_flex_delegate",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite:string_util",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ "@org_tensorflow//tensorflow/lite/tools:command_line_flags",
+ ] + select({
+ "@org_tensorflow//tensorflow:android": [
+ "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite",
+ ],
+ "@org_tensorflow//tensorflow:ios": [
+ "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+ }),
+)
diff --git a/tensorflow_lite_support/custom_ops/kernel/BUILD b/tensorflow_lite_support/custom_ops/kernel/BUILD
new file mode 100644
index 00000000..b9b11de9
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/BUILD
@@ -0,0 +1,146 @@
+# Placeholder for internal Python strict test compatibility macro.
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "whitespace_tokenizer",
+ srcs = ["whitespace_tokenizer.cc"],
+ hdrs = ["whitespace_tokenizer.h"],
+ deps = [
+ "@org_tensorflow//tensorflow/lite:context",
+ "@org_tensorflow//tensorflow/lite:string_util",
+ "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+ "@utf_archive//:utf",
+ ],
+)
+
+cc_library(
+ name = "whitespace_tokenizer_op_resolver",
+ srcs = ["whitespace_tokenizer_op_resolver.cc"],
+ hdrs = ["whitespace_tokenizer_op_resolver.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":whitespace_tokenizer",
+ "@org_tensorflow//tensorflow/lite:framework",
+ ],
+)
+
+pybind_extension(
+ name = "_pywrap_whitespace_tokenizer_op_resolver",
+ srcs = ["whitespace_tokenizer_op_resolver_wrapper.cc"],
+ hdrs = ["whitespace_tokenizer_op_resolver.h"],
+ additional_exported_symbols = ["AddWhitespaceTokenizerCustomOp"],
+ module_name = "_pywrap_whitespace_tokenizer_op_resolver",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":whitespace_tokenizer_op_resolver",
+ "@local_config_python//:python_headers",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@pybind11",
+ ],
+)
+
+cc_test(
+ name = "whitespace_tokenizer_test",
+ srcs = ["whitespace_tokenizer_test.cc"],
+ deps = [
+ ":whitespace_tokenizer",
+ "@com_google_googletest//:gtest_main",
+ "@org_tensorflow//tensorflow/lite:string_util",
+ "@org_tensorflow//tensorflow/lite/kernels:test_util",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
+
+py_test(
+ name = "whitespace_tokenizer_py_test",
+ srcs = ["whitespace_tokenizer_test.py"],
+ data = [
+ "testdata/whitespace_tokenizer_flex_delegate.tflite",
+ "testdata/whitespace_tokenizer_to_ragged_1d_input.tflite",
+ "testdata/whitespace_tokenizer_to_ragged_2d_input.tflite",
+ "testdata/whitespace_tokenizer_to_tensor.tflite",
+ ],
+ main = "whitespace_tokenizer_test.py",
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_pywrap_whitespace_tokenizer_op_resolver",
+ # numpy dep,
+ # tensorflow dep,
+ # tensorflow_text dep,
+ "@absl_py//absl/logging",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cc_library(
+ name = "ngrams",
+ srcs = ["ngrams.cc"],
+ hdrs = ["ngrams.h"],
+ deps = [
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:context",
+ "@org_tensorflow//tensorflow/lite:string_util",
+ "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+ ],
+)
+
+cc_library(
+ name = "ngrams_op_resolver",
+ srcs = ["ngrams_op_resolver.cc"],
+ hdrs = ["ngrams_op_resolver.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":ngrams",
+ "@org_tensorflow//tensorflow/lite:framework",
+ ],
+)
+
+pybind_extension(
+ name = "_pywrap_ngrams_op_resolver",
+ srcs = ["ngrams_op_resolver_wrapper.cc"],
+ hdrs = ["ngrams_op_resolver.h"],
+ additional_exported_symbols = ["AddNgramsCustomOp"],
+ module_name = "_pywrap_ngrams_op_resolver",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":ngrams_op_resolver",
+ "@local_config_python//:python_headers",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@pybind11",
+ ],
+)
+
+cc_test(
+ name = "ngrams_test",
+ srcs = ["ngrams_test.cc"],
+ deps = [
+ ":ngrams",
+ "@com_google_googletest//:gtest_main",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:string_util",
+ "@org_tensorflow//tensorflow/lite/kernels:test_util",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
+
+py_test(
+ name = "ngrams_py_test",
+ srcs = ["ngrams_test.py"],
+ main = "ngrams_test.py",
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":_pywrap_ngrams_op_resolver",
+ # tensorflow dep,
+ # tensorflow_text dep,
+ "//tensorflow_lite_support/custom_ops/python:tflite_text_api",
+ "@absl_py//absl/logging",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams.cc b/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
new file mode 100644
index 00000000..3831c63c
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ngrams.cc
@@ -0,0 +1,208 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h"
+
+#include "flatbuffers/flexbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace ngrams {
+
+// This TFLite op implements text.ngrams when reduction_type = STRING_JOIN.
+//
+// Input:
+// * data: A string tensor, or a ragged string tensor (a 1D string value tensor
+// and one or more 1D int64 row_split tensors).
+//
+// Attributes:
+// * width: scalar integer
+// The width of the ngram window.
+// * axis: scalar integer
+// The axis to create ngrams along. For STRING_JOIN, this must be -1.
+// * reduction_type: scalar string
+//     A string corresponding to the name of an enum value of text.Reduction.
+// Currently, only STRING_JOIN is supported.
+// * string_separator: scalar string
+// The separator string used to join tokens together.
+//
+// Output:
+// * output: A string tensor that matches the rank of 'data'. Will be a ragged
+// tensor if 'data' is a ragged tensor.
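+//
+// Illustrative example with width=2 and string_separator=" ": a row of
+// tokens ["a", "b", "c"] produces ["a b", "b c"], shrinking the innermost
+// dimension from 3 to 3 - 2 + 1 = 2.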
+
+// Both the input and output tensors use the same indices.
+constexpr int kValues = 0;
+constexpr int kRowSplitsStart = 1;
+
+// Reduction types.
+constexpr char kStringJoin[] = "STRING_JOIN";
+
+struct NgramsAttributes {
+ int width;
+ int axis;
+ std::string reduction_type;
+ std::string string_separator;
+
+ explicit NgramsAttributes(const flexbuffers::Map& m)
+ : width(m["width"].AsInt32()),
+ axis(m["axis"].AsInt32()),
+ reduction_type(m["reduction_type"].ToString()),
+ string_separator(m["string_separator"].ToString()) {}
+};
+
+inline bool OutputIsTensor(TfLiteNode* node) { return NumOutputs(node) == 1; }
+inline int NumRowSplits(TfLiteNode* node) {
+ return NumInputs(node) - kRowSplitsStart;
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ return new NgramsAttributes(flexbuffers::GetRoot(buffer_t, length).AsMap());
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<NgramsAttributes*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto& attributes =
+ *reinterpret_cast<NgramsAttributes*>(node->user_data);
+
+ TF_LITE_ENSURE(context, attributes.reduction_type == kStringJoin);
+ TF_LITE_ENSURE(context, attributes.axis == -1);
+
+ TfLiteTensor* output_values = GetOutput(context, node, kValues);
+ if (OutputIsTensor(node)) {
+ const TfLiteTensor* input_values = GetInput(context, node, kValues);
+ int values_num_dims = NumDimensions(input_values);
+ TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(values_num_dims);
+ for (int i = 0; i < values_num_dims; ++i) {
+ output_values_shape->data[i] = SizeOfDimension(input_values, i);
+ }
+ output_values_shape->data[values_num_dims - 1] =
+ std::max(0, SizeOfDimension(input_values, values_num_dims - 1) -
+ attributes.width + 1);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_values,
+ output_values_shape));
+ return kTfLiteOk;
+ }
+
+ SetTensorToDynamic(output_values);
+ // The row_splits tensors maintain their shape, because only the
+ // innermost dimension will change.
+ for (int i = kRowSplitsStart; i < NumOutputs(node); ++i) {
+ const TfLiteTensor* input_row_splits = GetInput(context, node, i);
+ TfLiteTensor* output_row_splits = GetOutput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input_row_splits), 1);
+ TfLiteIntArray* output_row_splits_shape = TfLiteIntArrayCreate(1);
+ output_row_splits_shape->data[0] = SizeOfDimension(input_row_splits, 0);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_row_splits,
+ output_row_splits_shape));
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto& attributes =
+ *reinterpret_cast<NgramsAttributes*>(node->user_data);
+
+ // Storage for the dummy input and output row_splits used in the tensor case.
+ std::vector<int64_t> tensor_input_row_splits;
+ std::vector<int64_t> tensor_output_row_splits;
+
+ const int64_t* input_row_splits;
+ int64_t* output_row_splits;
+ int n_row_splits = 0;
+
+ const TfLiteTensor* input_values = GetInput(context, node, kValues);
+
+ if (OutputIsTensor(node)) {
+ // Generate mock input and output innermost row_splits.
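+    // For example, a [2, 3, 4] string tensor has tokens_per_element = 4 and
+    // mock row_splits [0, 4, 8, 12, 16, 20, 24].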
+ int64_t total_tokens = NumElements(input_values);
+ int64_t tokens_per_element =
+ SizeOfDimension(input_values, NumDimensions(input_values) - 1);
+ tensor_input_row_splits.reserve(total_tokens / tokens_per_element + 1);
+ tensor_output_row_splits.resize(total_tokens / tokens_per_element + 1);
+ for (int64_t i = 0; i <= total_tokens; i += tokens_per_element) {
+ tensor_input_row_splits.push_back(i);
+ }
+ input_row_splits = tensor_input_row_splits.data();
+ output_row_splits = tensor_output_row_splits.data();
+ n_row_splits = tensor_input_row_splits.size();
+ } else {
+ int index = 0;
+ while (index < NumRowSplits(node) - 1) {
+ const TfLiteTensor* input_tensor_row_splits =
+ GetInput(context, node, kRowSplitsStart + index);
+ TfLiteTensor* output_tensor_row_splits =
+ GetOutput(context, node, kRowSplitsStart + index);
+ memcpy(output_tensor_row_splits->data.raw,
+ input_tensor_row_splits->data.raw, input_tensor_row_splits->bytes);
+ ++index;
+ }
+
+ const TfLiteTensor* input_tensor_row_splits =
+ GetInput(context, node, kRowSplitsStart + index);
+ TfLiteTensor* output_tensor_row_splits =
+ GetOutput(context, node, kRowSplitsStart + index);
+ input_row_splits = input_tensor_row_splits->data.i64;
+ output_row_splits = output_tensor_row_splits->data.i64;
+ n_row_splits = SizeOfDimension(input_tensor_row_splits, 0);
+ }
+
+ DynamicBuffer buffer;
+ StringRef separator;
+ separator.str = attributes.string_separator.c_str();
+ separator.len = attributes.string_separator.length();
+ int buffer_index = 0;
+ for (int i = 0; i < n_row_splits - 1; ++i) {
+ output_row_splits[i] = buffer_index;
+ std::vector<StringRef> tokens;
+ for (int j = input_row_splits[i]; j < input_row_splits[i + 1]; ++j) {
+ tokens.emplace_back(GetString(input_values, j));
+ if (tokens.size() < attributes.width) continue;
+ tokens.erase(tokens.begin(),
+ tokens.begin() + tokens.size() - attributes.width);
+ buffer.AddJoinedString(tokens, separator);
+ ++buffer_index;
+ }
+ }
+ output_row_splits[n_row_splits - 1] = buffer_index;
+
+ TfLiteTensor* output_values = GetOutput(context, node, kValues);
+ if (OutputIsTensor(node)) {
+ buffer.WriteToTensor(output_values, /*new_shape=*/nullptr);
+ } else {
+ buffer.WriteToTensorAsVector(output_values);
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace ngrams
+
+TfLiteRegistration* Register_tftext_Ngrams() {
+ static TfLiteRegistration r = {ngrams::Init, ngrams::Free, ngrams::Prepare,
+ ngrams::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams.h b/tensorflow_lite_support/custom_ops/kernel/ngrams.h
new file mode 100644
index 00000000..56229065
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ngrams.h
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_H_
+
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_tftext_Ngrams();
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
new file mode 100644
index 00000000..b87fcac3
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h"
+
+#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+void AddNgramsCustomOp(MutableOpResolver* resolver) {
+ resolver->AddCustom("tftext:Ngrams", Register_tftext_Ngrams());
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h
new file mode 100644
index 00000000..fc932688
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_OP_RESOLVER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_OP_RESOLVER_H_
+
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+// Adds the Ngrams custom op to an op resolver.
+// This function can be loaded using dlopen. Since C++ function names get
+// mangled, declare this function as extern C, so its name is unchanged.
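+//
+// A minimal usage sketch (not authoritative; assumes the usual TFLite C++
+// flow where the resolver is then passed to tflite::InterpreterBuilder):
+//   tflite::MutableOpResolver resolver;
+//   tflite::ops::custom::AddNgramsCustomOp(&resolver);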
+extern "C" void AddNgramsCustomOp(MutableOpResolver* resolver);
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_OP_RESOLVER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver_wrapper.cc b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver_wrapper.cc
new file mode 100644
index 00000000..82747309
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver_wrapper.cc
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
+#include "tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h"
+
+PYBIND11_MODULE(_pywrap_ngrams_op_resolver, m) {
+ m.doc() = "_pywrap_ngrams_op_resolver";
+ m.def(
+ "AddNgramsCustomOp",
+ [](uintptr_t resolver) {
+ tflite::ops::custom::AddNgramsCustomOp(
+ reinterpret_cast<tflite::MutableOpResolver*>(resolver));
+ },
+ "Op registerer function for the tftext:Ngrams custom op.");
+}
diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc b/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
new file mode 100644
index 00000000..91ef47af
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc
@@ -0,0 +1,293 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h"
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace ngrams {
+namespace test {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+} // namespace
+
+class NgramsModel : public SingleOpModel {
+ public:
+ // Constructor for testing the op with a tf.Tensor
+ NgramsModel(int width, const std::string& string_separator,
+ const std::vector<std::string>& input_values,
+ const std::vector<int>& input_shape) {
+ input_values_ = AddInput(TensorType_STRING);
+ output_values_ = AddOutput(TensorType_STRING);
+
+ BuildCustomOp(width, string_separator);
+
+ BuildInterpreter({input_shape});
+ PopulateStringTensor(input_values_, input_values);
+ Invoke();
+ }
+
+ // Constructor for the op with a tf.RaggedTensor
+ // Note: This interface uses row_lengths, as they're closer to the
+ // dimensions in a TensorShape, but internally everything is row_splits.
+ NgramsModel(int width, const std::string& string_separator,
+ const std::vector<std::string>& input_values,
+ const std::vector<std::vector<int64_t>> nested_row_lengths) {
+ std::vector<std::vector<int>> input_shapes;
+ input_shapes.reserve(nested_row_lengths.size() + 1);
+
+ input_values_ = AddInput(TensorType_STRING);
+ input_shapes.push_back({static_cast<int>(input_values.size())});
+ output_values_ = AddOutput(TensorType_STRING);
+
+ input_row_splits_.reserve(nested_row_lengths.size());
+ output_row_splits_.reserve(nested_row_lengths.size());
+ for (int i = 0; i < nested_row_lengths.size(); ++i) {
+ input_row_splits_.push_back(AddInput(TensorType_INT64));
+ input_shapes.push_back(
+ {static_cast<int>(nested_row_lengths[i].size() + 1)});
+ output_row_splits_.push_back(AddOutput(TensorType_INT64));
+ }
+
+ BuildCustomOp(width, string_separator);
+
+ BuildInterpreter(input_shapes);
+ PopulateStringTensor(input_values_, input_values);
+ for (int i = 0; i < nested_row_lengths.size(); ++i) {
+ std::vector<int64_t> row_splits;
+ row_splits.reserve(nested_row_lengths[i].size() + 1);
+ int64_t index = 0;
+ row_splits.push_back(index);
+ for (int64_t row_length : nested_row_lengths[i]) {
+ index += row_length;
+ row_splits.push_back(index);
+ }
+ PopulateTensor(input_row_splits_[i], row_splits);
+ }
+ Invoke();
+ }
+
+ std::vector<int> GetValuesTensorShape() {
+ return GetTensorShape(output_values_);
+ }
+
+ std::vector<std::string> ExtractValuesTensorVector() {
+ std::vector<std::string> r;
+ TfLiteTensor* tensor = interpreter_->tensor(output_values_);
+ int n = GetStringCount(tensor);
+ for (int i = 0; i < n; ++i) {
+ StringRef ref = GetString(tensor, i);
+ r.emplace_back(ref.str, ref.len);
+ }
+ return r;
+ }
+
+ int GetNumNestedRowLengths() { return output_row_splits_.size(); }
+
+ std::vector<int> GetRowLengthsTensorShape(int i) {
+ std::vector<int> shape = GetTensorShape(output_row_splits_[i]);
+ --shape[0];
+ return shape;
+ }
+
+ std::vector<int64_t> ExtractRowLengthsTensorVector(int i) {
+ std::vector<int64_t> row_splits =
+ ExtractVector<int64_t>(output_row_splits_[i]);
+ std::vector<int64_t> row_lengths;
+ row_lengths.reserve(row_splits.size() - 1);
+ int64_t head = row_splits[0];
+    for (size_t j = 1; j < row_splits.size(); ++j) {
+      int64_t tail = row_splits[j];
+ row_lengths.push_back(tail - head);
+ head = tail;
+ }
+ return row_lengths;
+ }
+
+ private:
+ void BuildCustomOp(int width, const std::string& string_separator) {
+ flexbuffers::Builder fbb;
+ size_t start_map = fbb.StartMap();
+ fbb.Int("width", width);
+ fbb.String("string_separator", string_separator);
+ fbb.Int("axis", -1);
+ fbb.String("reduction_type", "STRING_JOIN");
+ fbb.EndMap(start_map);
+ fbb.Finish();
+
+ SetCustomOp("tftext:Ngrams", fbb.GetBuffer(), Register_tftext_Ngrams);
+ }
+
+ int input_values_;
+ std::vector<int> input_row_splits_;
+ int output_values_;
+ std::vector<int> output_row_splits_;
+};
+
+TEST(NgramsTest, TensorSingleSequenceWidthTwo) {
+ NgramsModel m(2, " ", {"this", "is", "a", "test"}, std::vector<int>{4});
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3));
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("this is", "is a", "a test"));
+}
+
+TEST(NgramsTest, TensorSingleSequenceWidthThree) {
+ NgramsModel m(3, " ", {"this", "is", "a", "test"}, std::vector<int>{4});
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(2));
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("this is a", "is a test"));
+}
+
+TEST(NgramsTest, TensorSingleSequenceLongerSeparator) {
+ NgramsModel m(2, "...", {"this", "is", "a", "test"}, std::vector<int>{4});
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3));
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("this...is", "is...a", "a...test"));
+}
+
+TEST(NgramsTest, TensorSingleSequenceWidthTooLong) {
+ NgramsModel m(5, " ", {"this", "is", "a", "test"}, std::vector<int>{4});
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(0));
+ EXPECT_THAT(m.ExtractValuesTensorVector(), ElementsAre());
+}
+
+TEST(NgramsTest, TensorMultidimensionalInputWidthTwo) {
+ NgramsModel m(2, " ",
+ {
+ "0,0,0", "0,0,1", "0,0,2", "0,0,3", //
+ "0,1,0", "0,1,1", "0,1,2", "0,1,3", //
+ "0,2,0", "0,2,1", "0,2,2", "0,2,3", //
+ "1,0,0", "1,0,1", "1,0,2", "1,0,3", //
+ "1,1,0", "1,1,1", "1,1,2", "1,1,3", //
+ "1,2,0", "1,2,1", "1,2,2", "1,2,3", //
+ },
+ std::vector<int>{2, 3, 4});
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(2, 3, 3));
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAreArray({
+ "0,0,0 0,0,1", "0,0,1 0,0,2", "0,0,2 0,0,3", //
+ "0,1,0 0,1,1", "0,1,1 0,1,2", "0,1,2 0,1,3", //
+ "0,2,0 0,2,1", "0,2,1 0,2,2", "0,2,2 0,2,3", //
+ "1,0,0 1,0,1", "1,0,1 1,0,2", "1,0,2 1,0,3", //
+ "1,1,0 1,1,1", "1,1,1 1,1,2", "1,1,2 1,1,3", //
+ "1,2,0 1,2,1", "1,2,1 1,2,2", "1,2,2 1,2,3", //
+ }));
+}
+
+TEST(NgramsTest, RaggedTensorSingleSequenceWidthTwo) {
+ std::vector<std::vector<int64_t>> nested_row_lengths;
+ nested_row_lengths.push_back({4});
+  NgramsModel m(2, " ", {"this", "is", "a", "test"}, nested_row_lengths);
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3));
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("this is", "is a", "a test"));
+ ASSERT_THAT(m.GetNumNestedRowLengths(), 1);
+ EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1));
+ EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(3));
+}
+
+TEST(NgramsTest, RaggedTensorSingleSequenceWidthThree) {
+ std::vector<std::vector<int64_t>> nested_row_lengths;
+ nested_row_lengths.push_back({4});
+ NgramsModel m(3, " ", {"this", "is", "a", "test"}, nested_row_lengths);
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(2));
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("this is a", "is a test"));
+ ASSERT_THAT(m.GetNumNestedRowLengths(), 1);
+ EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1));
+ EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(2));
+}
+
+TEST(NgramsTest, RaggedTensorSingleSequenceLongerSeparator) {
+ std::vector<std::vector<int64_t>> nested_row_lengths;
+ nested_row_lengths.push_back({4});
+ NgramsModel m(2, "<>", {"this", "is", "a", "test"}, nested_row_lengths);
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3));
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("this<>is", "is<>a", "a<>test"));
+ ASSERT_THAT(m.GetNumNestedRowLengths(), 1);
+ EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1));
+ EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(3));
+}
+
+TEST(NgramsTest, RaggedTensorSingleSequenceWidthTooLong) {
+ std::vector<std::vector<int64_t>> nested_row_lengths;
+ nested_row_lengths.push_back({4});
+ NgramsModel m(5, " ", {"this", "is", "a", "test"}, nested_row_lengths);
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(0));
+ EXPECT_THAT(m.ExtractValuesTensorVector(), ElementsAre());
+ ASSERT_THAT(m.GetNumNestedRowLengths(), 1);
+ EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1));
+ EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(0));
+}
+
+TEST(NgramsTest, RaggedTensorMultidimensionalInputWidthTwo) {
+ std::vector<std::vector<int64_t>> nested_row_lengths;
+ nested_row_lengths.push_back({4, 2, 1});
+ nested_row_lengths.push_back({5, 4, 3, 2, 2, 3, 4, 6});
+ NgramsModel m(2, " ",
+ {
+ "0,0,0", "0,0,1", "0,0,2", "0,0,3", "0,0,4", //
+ "0,1,0", "0,1,1", "0,1,2", "0,1,3", //
+ "0,2,0", "0,2,1", "0,2,2", //
+ "0,3,0", "0,3,1", //
+ "1,0,0", "1,0,1", //
+ "1,1,0", "1,1,1", "1,1,2", //
+ "1,2,0", "1,2,1", "1,2,2", "1,2,3", //
+ "2,0,0", "2,0,1", "2,0,2", "2,0,3", "2,0,4", "2,0,5", //
+ },
+ nested_row_lengths);
+
+ std::vector<std::string> expected_values = {
+ "0,0,0 0,0,1", "0,0,1 0,0,2", "0,0,2 0,0,3", "0,0,3 0,0,4", //
+ "0,1,0 0,1,1", "0,1,1 0,1,2", "0,1,2 0,1,3", //
+ "0,2,0 0,2,1", "0,2,1 0,2,2", //
+ "0,3,0 0,3,1", //
+ "1,0,0 1,0,1", //
+ "1,1,0 1,1,1", "1,1,1 1,1,2", //
+ "1,2,0 1,2,1", "1,2,1 1,2,2", "1,2,2 1,2,3", //
+ "2,0,0 2,0,1", "2,0,1 2,0,2", "2,0,2 2,0,3", "2,0,3 2,0,4",
+ "2,0,4 2,0,5", //
+ };
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(expected_values.size()));
+ EXPECT_THAT(m.ExtractValuesTensorVector(), ElementsAreArray(expected_values));
+ ASSERT_THAT(m.GetNumNestedRowLengths(), 2);
+ EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(3));
+ EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(4, 2, 1));
+ EXPECT_THAT(m.GetRowLengthsTensorShape(1), ElementsAre(8));
+ EXPECT_THAT(m.ExtractRowLengthsTensorVector(1),
+ ElementsAre(4, 3, 2, 1, 1, 2, 3, 5));
+}
+
+} // namespace test
+} // namespace ngrams
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py b/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py
new file mode 100644
index 00000000..e52ca285
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py
@@ -0,0 +1,266 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Lint as: python3
+"""Tests for tensorflow_lite_support.custom_ops.ngrams."""
+
+import os
+import sys
+import timeit
+
+from absl import logging
+from absl.testing import parameterized
+import tensorflow as tf
+import tensorflow_text as tf_text
+from tensorflow.lite.python import interpreter as interpreter_wrapper # pylint: disable=g-direct-tensorflow-import
+from tensorflow_lite_support.custom_ops.python import tflite_text_api
+
+# Force loaded shared object symbols to be globally visible. This is needed so
+# that the interpreter_wrapper, in one .so file, can see the op resolver
+# in a different .so file. Note that this may already be set by default.
+# pylint: disable=g-import-not-at-top,g-bad-import-order,unused-import
+if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
+ sys.setdlopenflags(sys.getdlopenflags() | os.RTLD_GLOBAL)
+from tensorflow_lite_support.custom_ops.kernel import _pywrap_ngrams_op_resolver
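+
+# The tests below pass custom_op_registerers=['AddNgramsCustomOp'] to
+# InterpreterWithCustomOps, which looks up the registerer by symbol name, so
+# the symbol exported above must be globally visible.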
+
+TEST_CASES = [
+ [['this', 'is', 'a', 'test']],
+ [['one']],
+ [['two', 'tokens'], ['a', 'b']],
+ [['has', 'three', 'tokens'], ['a', 'b', 'c'], ['0', '1', '2']],
+ [['a', 'ragged', 'tensor'], ['a'], ['0', '1']],
+ [[['a', 'multidimensional', 'test', 'case'], ['a', 'b', 'c', 'd', 'e']],
+ [['0', '1', '2', '3', '4', '5']]],
+]
+
+INVOKES_FOR_SINGLE_OP_BENCHMARK = 1000
+INVOKES_FOR_FLEX_DELEGATE_BENCHMARK = 100
+
+
+class NgramsTest(parameterized.TestCase):
+
+ _models = {}
+
+ def _make_model(self, rank, width, ragged_tensor=False, flex=False):
+ temp_dir = self.create_tempdir().full_path
+
+ key = (rank, width, ragged_tensor, flex)
+ if key in self._models:
+ return self._models[key]
+
+ ngrams = tf_text.ngrams if flex else tflite_text_api.ngrams
+
+ if ragged_tensor:
+ input_signature = [tf.TensorSpec(shape=[None], dtype=tf.string)]
+ rs = rank - 1
+ input_signature += [tf.TensorSpec(shape=[None], dtype=tf.int64)] * rs
+
+ class Model(tf.Module):
+
+ @tf.function(input_signature=input_signature)
+ def __call__(self, values, *args):
+ row_splits = list(args)
+ row_splits.reverse()
+ input_tensor = tf.RaggedTensor.from_nested_row_splits(
+ flat_values=values, nested_row_splits=tuple(row_splits))
+ output_tensor = ngrams(
+ input_tensor, width, reduction_type=tf_text.Reduction.STRING_JOIN)
+ output = [output_tensor.flat_values]
+ output.extend(list(output_tensor.nested_row_splits))
+ output.reverse()
+ return tuple(output)
+
+ tf.saved_model.save(Model(), temp_dir)
+ else:
+ shape = [None] * rank
+
+ class Model(tf.Module):
+
+ @tf.function(
+ input_signature=[tf.TensorSpec(shape=shape, dtype=tf.string)])
+ def __call__(self, input_tensor):
+ return ngrams(
+ input_tensor, width, reduction_type=tf_text.Reduction.STRING_JOIN)
+
+ tf.saved_model.save(Model(), temp_dir)
+
+ converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir)
+ converter.inference_type = tf.float32
+ converter.inference_input_type = tf.float32
+ converter.allow_custom_ops = not flex
+ if flex:
+ converter.target_spec.supported_ops = [
+ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
+ ]
+ model = converter.convert()
+ self._models[key] = model
+ return model
+
+ @parameterized.parameters([t] for t in TEST_CASES)
+ def test_width_2_tensor_equivalence(self, test_case):
+ input_tensor = tf.ragged.constant(test_case).to_tensor()
+ tf_output = tf_text.ngrams(
+ input_tensor, 2, reduction_type=tf_text.Reduction.STRING_JOIN)
+
+ rank = input_tensor.shape.rank
+ model = self._make_model(rank, 2, ragged_tensor=False, flex=False)
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=model, custom_op_registerers=['AddNgramsCustomOp'])
+ interpreter.resize_tensor_input(0, input_tensor.shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ input_tensor.numpy())
+ interpreter.invoke()
+ tflite_output = interpreter.get_tensor(
+ interpreter.get_output_details()[0]['index'])
+
+ self.assertEqual(tf_output.numpy().tolist(), tflite_output.tolist())
+
+ @parameterized.parameters([t] for t in TEST_CASES)
+ def test_width_3_tensor_equivalence(self, test_case):
+ input_tensor = tf.ragged.constant(test_case).to_tensor()
+ tf_output = tf_text.ngrams(
+ input_tensor, 3, reduction_type=tf_text.Reduction.STRING_JOIN)
+
+ rank = input_tensor.shape.rank
+ model = self._make_model(rank, 3, ragged_tensor=False, flex=False)
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=model, custom_op_registerers=['AddNgramsCustomOp'])
+ interpreter.resize_tensor_input(0, input_tensor.shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ input_tensor.numpy())
+ interpreter.invoke()
+ tflite_output = interpreter.get_tensor(
+ interpreter.get_output_details()[0]['index'])
+ self.assertEqual(tf_output.numpy().tolist(), tflite_output.tolist())
+
+ @parameterized.parameters([t] for t in TEST_CASES)
+ def test_width_2_ragged_tensor_equivalence(self, test_case):
+ input_tensor = tf.ragged.constant(test_case)
+ tf_output = tf_text.ngrams(
+ input_tensor, 2, reduction_type=tf_text.Reduction.STRING_JOIN)
+
+ rank = input_tensor.shape.rank
+ model = self._make_model(rank, 2, ragged_tensor=True, flex=False)
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=model, custom_op_registerers=['AddNgramsCustomOp'])
+ interpreter.resize_tensor_input(0, input_tensor.flat_values.shape)
+ for r in range(rank - 1):
+ interpreter.resize_tensor_input(r + 1,
+ input_tensor.nested_row_splits[r].shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ input_tensor.flat_values.numpy())
+ for r in range(rank - 1):
+ interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'],
+ input_tensor.nested_row_splits[r].numpy())
+ interpreter.invoke()
+ tflite_output_values = interpreter.get_tensor(
+ interpreter.get_output_details()[0]['index'])
+ self.assertEqual(tf_output.flat_values.numpy().tolist(),
+ tflite_output_values.tolist())
+ for i in range(rank - 1):
+ tflite_output_cur_row_splits = interpreter.get_tensor(
+ interpreter.get_output_details()[i + 1]['index'])
+ self.assertEqual(tf_output.nested_row_splits[i].numpy().tolist(),
+ tflite_output_cur_row_splits.tolist())
+
+ @parameterized.parameters([t] for t in TEST_CASES)
+ def test_width_3_ragged_tensor_equivalence(self, test_case):
+ input_tensor = tf.ragged.constant(test_case)
+ tf_output = tf_text.ngrams(
+ input_tensor, 3, reduction_type=tf_text.Reduction.STRING_JOIN)
+
+ rank = input_tensor.shape.rank
+ model = self._make_model(rank, 3, ragged_tensor=True, flex=False)
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=model, custom_op_registerers=['AddNgramsCustomOp'])
+ interpreter.resize_tensor_input(0, input_tensor.flat_values.shape)
+ for r in range(rank - 1):
+ interpreter.resize_tensor_input(r + 1,
+ input_tensor.nested_row_splits[r].shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ input_tensor.flat_values.numpy())
+ for r in range(rank - 1):
+ interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'],
+ input_tensor.nested_row_splits[r].numpy())
+ interpreter.invoke()
+ tflite_output_values = interpreter.get_tensor(
+ interpreter.get_output_details()[0]['index'])
+ self.assertEqual(tf_output.flat_values.numpy().tolist(),
+ tflite_output_values.tolist())
+ for i in range(rank - 1):
+ tflite_output_cur_row_splits = interpreter.get_tensor(
+ interpreter.get_output_details()[i + 1]['index'])
+ self.assertEqual(tf_output.nested_row_splits[i].numpy().tolist(),
+ tflite_output_cur_row_splits.tolist())
+
+ def test_latency(self):
+ latency_op = 0.0
+ for test_case in TEST_CASES:
+ input_tensor = tf.ragged.constant(test_case)
+
+ rank = input_tensor.shape.rank
+ model = self._make_model(rank, 3, ragged_tensor=True, flex=False)
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=model, custom_op_registerers=['AddNgramsCustomOp'])
+ interpreter.resize_tensor_input(0, input_tensor.flat_values.shape)
+ for r in range(rank - 1):
+ interpreter.resize_tensor_input(r + 1,
+ input_tensor.nested_row_splits[r].shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ input_tensor.flat_values.numpy())
+ for r in range(rank - 1):
+ interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'],
+ input_tensor.nested_row_splits[r].numpy())
+ start_time = timeit.default_timer()
+ for _ in range(INVOKES_FOR_SINGLE_OP_BENCHMARK):
+ interpreter.invoke()
+ latency_op = latency_op + timeit.default_timer() - start_time
+ latency_op = latency_op / (
+ INVOKES_FOR_SINGLE_OP_BENCHMARK * len(TEST_CASES))
+
+ latency_flex = 0.0
+ for test_case in TEST_CASES:
+ input_tensor = tf.ragged.constant(test_case)
+
+ rank = input_tensor.shape.rank
+ model = self._make_model(rank, 3, ragged_tensor=True, flex=True)
+ interpreter = interpreter_wrapper.Interpreter(model_content=model)
+ interpreter.resize_tensor_input(0, input_tensor.flat_values.shape)
+ for r in range(rank - 1):
+ interpreter.resize_tensor_input(r + 1,
+ input_tensor.nested_row_splits[r].shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ input_tensor.flat_values.numpy())
+ for r in range(rank - 1):
+ interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'],
+ input_tensor.nested_row_splits[r].numpy())
+ start_time = timeit.default_timer()
+ for _ in range(INVOKES_FOR_FLEX_DELEGATE_BENCHMARK):
+ interpreter.invoke()
+ latency_flex = latency_flex + timeit.default_timer() - start_time
+ latency_flex = latency_flex / (
+ INVOKES_FOR_FLEX_DELEGATE_BENCHMARK * len(TEST_CASES))
+
+ logging.info('Latency (single op): %fms', latency_op * 1000.0)
+ logging.info('Latency (flex delegate): %fms', latency_flex * 1000.0)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD b/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD
new file mode 100644
index 00000000..a512cdc8
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD
@@ -0,0 +1,81 @@
+# RaggedTensors support in TFLite
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "ragged_tensor_to_tensor_tflite",
+ srcs = ["ragged_tensor_to_tensor_tflite.cc"],
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/core/util:ragged_to_dense_util_common",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:types",
+ ],
+)
+
+cc_test(
+ name = "ragged_tensor_to_tensor_tflite_test",
+ srcs = ["ragged_tensor_to_tensor_tflite_test.cc"],
+ deps = [
+ ":ragged_tensor_to_tensor_tflite",
+ "@com_google_googletest//:gtest_main",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:test_util",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
+
+cc_library(
+ name = "py_tflite_registerer",
+ srcs = ["py_tflite_registerer.cc"],
+ hdrs = ["py_tflite_registerer.h"],
+ deps = [
+ ":ragged_tensor_to_tensor_tflite",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "ragged_range_tflite",
+ srcs = ["ragged_range_tflite.cc"],
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:types",
+ ],
+)
+
+cc_test(
+ name = "ragged_range_tflite_test",
+ srcs = ["ragged_range_tflite_test.cc"],
+ deps = [
+ ":ragged_range_tflite",
+ "@com_google_googletest//:gtest_main",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:test_util",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ ],
+)
diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD b/tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD
new file mode 100644
index 00000000..650ab90b
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD
@@ -0,0 +1,27 @@
+# Python wrapper used for tests.
+
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+
+package(
+ default_visibility = [
+ "//tensorflow_lite_support:users",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+pybind_extension(
+ name = "pywrap_tflite_registerer",
+ srcs = [
+ "pywrap_tflite_registerer.cc",
+ ],
+ additional_exported_symbols = ["TFLite_RaggedTensorToTensorRegisterer"],
+ module_name = "pywrap_tflite_registerer",
+ srcs_version = "PY3ONLY",
+ deps = [
+ "//tensorflow_lite_support/custom_ops/kernel/ragged:py_tflite_registerer",
+ "@local_config_python//:python_headers",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ "@pybind11",
+ ],
+)
diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/py/pywrap_tflite_registerer.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/py/pywrap_tflite_registerer.cc
new file mode 100644
index 00000000..0b9432a9
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ragged/py/pywrap_tflite_registerer.cc
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h"
+
+PYBIND11_MODULE(pywrap_tflite_registerer, m) {
+ m.doc() = R"pbdoc(
+ pywrap_tflite_registerer
+    A module exposing a registerer that adds the TFLite
+    ragged_tensor_to_tensor custom op to an op resolver.
+ )pbdoc";
+ m.def(
+ "TFLite_RaggedTensorToTensorRegisterer",
+ [](uintptr_t resolver) {
+ TFLite_RaggedTensorToTensorRegisterer(
+ reinterpret_cast<tflite::MutableOpResolver*>(resolver));
+ },
+ R"pbdoc(
+ The function that adds RaggedTensorToTensor to the TFLite interpreter.
+ )pbdoc");
+}
diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.cc
new file mode 100644
index 00000000..7c93d8b1
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.cc
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h"
+
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+extern "C" void TFLite_RaggedTensorToTensorRegisterer(
+ tflite::MutableOpResolver* resolver) {
+ resolver->AddCustom("RaggedTensorToTensor",
+ tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
+}
diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h b/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
new file mode 100644
index 00000000..ade3c5c1
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h
@@ -0,0 +1,25 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_
+
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+// C function that is called from the Python wrapper.
+
+extern "C" void TFLite_RaggedTensorToTensorRegisterer(
+ tflite::MutableOpResolver *resolver);
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
new file mode 100644
index 00000000..a35a6db9
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc
@@ -0,0 +1,192 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <functional>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace ragged {
+namespace ragged_range {
+namespace {
+constexpr int kInputStarts = 0;
+constexpr int kInputLimits = 1;
+constexpr int kInputDeltas = 2;
+
+constexpr int kOutputNestedSplits = 0;
+constexpr int kOutputDenseValues = 1;
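+
+// RaggedRange mirrors tf.ragged.range: each (start, limit, delta) triple
+// produces one output row. For example (as exercised in the tests), starts
+// [0, 5, 8, 5], limits [8, 7, 8, 1], and deltas [2, 1, 1, -1] yield
+// nested_splits [0, 4, 6, 6, 10] and dense_values
+// [0, 2, 4, 6, 5, 6, 5, 4, 3, 2].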
+
+TfLiteIntArray* IntArrayFromInt(int x) {
+ TfLiteIntArray* result = TfLiteIntArrayCreate(1);
+ result->data[0] = x;
+ return result;
+}
+
+// Returns the number of elements in the specified range.
+template <typename T, typename SPLITS_TYPE>
+SPLITS_TYPE RangeSize(T start, T limit, T delta) {
+ if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) {
+ return 0;
+ }
+ // The following is copied from tensorflow::RangeOp::Compute().
+ return (
+ std::is_integral<T>::value
+ ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
+ : std::ceil(std::abs((limit - start) / delta)));
+}
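+
+// For example, RangeSize(0, 8, 2) == 4 (values 0, 2, 4, 6): the integral
+// branch computes (|8 - 0| + |2| - 1) / |2| = 9 / 2 = 4.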
+
+template <typename T, typename SPLITS_TYPE>
+TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor& input_starts =
+ context->tensors[node->inputs->data[kInputStarts]];
+ TfLiteTensor& input_limits =
+ context->tensors[node->inputs->data[kInputLimits]];
+ TfLiteTensor& input_deltas =
+ context->tensors[node->inputs->data[kInputDeltas]];
+ // Determine which tensors we need to broadcast.
+ const bool broadcast_starts = NumElements(&input_starts) == 1;
+ const bool broadcast_limits = NumElements(&input_limits) == 1;
+ const bool broadcast_deltas = NumElements(&input_deltas) == 1;
+
+ // nrows (number of output rows) is the size of the non-broadcast inputs,
+ // or 1 if all inputs are scalars.
+ std::vector<int> in_sizes;
+ if (!broadcast_starts) in_sizes.push_back(input_starts.dims->data[0]);
+ if (!broadcast_limits) in_sizes.push_back(input_limits.dims->data[0]);
+ if (!broadcast_deltas) in_sizes.push_back(input_deltas.dims->data[0]);
+ if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes),
+ std::not_equal_to<>()) != std::end(in_sizes)) {
+ context->ReportError(
+ context,
+ "Invalid argument: starts, limits, and deltas must have the "
+ "same shape");
+ return kTfLiteError;
+ }
+
+ const SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes.front();
+
+ const T* starts = GetTensorData<T>(&input_starts);
+ const T* limits = GetTensorData<T>(&input_limits);
+ const T* deltas = GetTensorData<T>(&input_deltas);
+
+ TfLiteTensor& rt_nested_splits_out =
+ context->tensors[node->outputs->data[kOutputNestedSplits]];
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, &rt_nested_splits_out,
+ IntArrayFromInt(nrows + 1)));
+ SPLITS_TYPE* rt_nested_splits =
+ GetTensorData<SPLITS_TYPE>(&rt_nested_splits_out);
+ rt_nested_splits[0] = 0;
+
+ for (int row = 0; row < nrows; ++row) {
+ const T start = broadcast_starts ? starts[0] : starts[row];
+ const T limit = broadcast_limits ? limits[0] : limits[row];
+ const T delta = broadcast_deltas ? deltas[0] : deltas[row];
+ if (delta == 0) {
+ context->ReportError(context, "Invalid argument: Requires delta != 0");
+ return kTfLiteError;
+ }
+ rt_nested_splits[row + 1] =
+ rt_nested_splits[row] + RangeSize<T, SPLITS_TYPE>(start, limit, delta);
+ }
+ const SPLITS_TYPE nvals = rt_nested_splits[nrows];
+
+ TfLiteTensor& rt_dense_values_out =
+ context->tensors[node->outputs->data[kOutputDenseValues]];
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, &rt_dense_values_out,
+ IntArrayFromInt(nvals)));
+ T* rt_dense_values = GetTensorData<T>(&rt_dense_values_out);
+ int value_index = 0;
+ for (int row = 0; row < nrows; ++row) {
+ const SPLITS_TYPE row_size =
+ rt_nested_splits[row + 1] - rt_nested_splits[row];
+ T value = broadcast_starts ? starts[0] : starts[row];
+ const T delta = broadcast_deltas ? deltas[0] : deltas[row];
+ for (SPLITS_TYPE i = 0; i < row_size; ++i) {
+ rt_dense_values[value_index++] = value;
+ value += delta;
+ }
+ }
+ return kTfLiteOk;
+}
+
+template <typename SPLITS_TYPE>
+TfLiteStatus EvalSplitsT(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor& rt_dense_values_out =
+ context->tensors[node->outputs->data[kOutputDenseValues]];
+ switch (rt_dense_values_out.type) {
+ case kTfLiteInt32:
+ return EvalT<int32_t, SPLITS_TYPE>(context, node);
+ case kTfLiteInt64:
+ return EvalT<int64_t, SPLITS_TYPE>(context, node);
+ case kTfLiteFloat32:
+ return EvalT<float, SPLITS_TYPE>(context, node);
+ case kTfLiteFloat64:
+ return EvalT<double, SPLITS_TYPE>(context, node);
+ default:
+ context->ReportError(context,
+                           "Invalid argument: Unsupported VALUES type");
+ return kTfLiteError;
+ }
+}
+} // namespace
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // Set outputs dynamic.
+ TfLiteTensor& nested_splits =
+ context->tensors[node->outputs->data[kOutputNestedSplits]];
+ SetTensorToDynamic(&nested_splits);
+ TfLiteTensor& dense_values =
+ context->tensors[node->outputs->data[kOutputDenseValues]];
+ SetTensorToDynamic(&dense_values);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor& rt_nested_splits_out =
+ context->tensors[node->outputs->data[kOutputNestedSplits]];
+ switch (rt_nested_splits_out.type) {
+ case kTfLiteInt32:
+ return EvalSplitsT<int32_t>(context, node);
+ case kTfLiteInt64:
+ return EvalSplitsT<int64_t>(context, node);
+ default:
+ context->ReportError(context,
+                           "Invalid argument: Unsupported ROW_SPLITS type");
+ return kTfLiteError;
+ }
+}
+
+} // namespace ragged_range
+} // namespace ragged
+TfLiteRegistration* Register_RAGGED_RANGE() {
+ static TfLiteRegistration r = {nullptr /*Initialize*/, nullptr /*Free*/,
+ ragged::ragged_range::Prepare,
+ ragged::ragged_range::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
new file mode 100644
index 00000000..54cf4459
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc
@@ -0,0 +1,155 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <initializer_list>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_RAGGED_RANGE();
+} // namespace custom
+} // namespace ops
+
+namespace {
+
+template <typename T>
+class RaggedRangeOpModel : public SingleOpModel {
+ public:
+ static TensorType GetType();
+
+ RaggedRangeOpModel(const std::vector<T>& start, const std::vector<T>& limits,
+ const std::vector<T>& deltas) {
+ const TensorType value_type = GetType();
+ std::vector<std::vector<int>> shapes;
+ input_start_ = AddInput(value_type);
+ shapes.push_back({static_cast<int>(start.size())});
+ input_limits_ = AddInput(value_type);
+ shapes.push_back({static_cast<int>(limits.size())});
+ input_deltas_ = AddInput(value_type);
+ shapes.push_back({static_cast<int>(deltas.size())});
+
+ output_splits_ = AddOutput(TensorType_INT32);
+ output_values_ = AddOutput(value_type);
+
+ SetCustomOp("RaggedRange", {}, ops::custom::Register_RAGGED_RANGE);
+ BuildInterpreter(shapes);
+
+ PopulateTensor(input_start_, start);
+ PopulateTensor(input_limits_, limits);
+ PopulateTensor(input_deltas_, deltas);
+ }
+
+ std::vector<int32> GetSplits() {
+ return ExtractVector<int32>(output_splits_);
+ }
+ std::vector<T> GetValues() const { return ExtractVector<T>(output_values_); }
+
+ protected:
+ int input_start_ = -1;
+ int input_limits_ = -1;
+ int input_deltas_ = -1;
+
+ int output_splits_ = -1;
+ int output_values_ = -1;
+};
+
+template <>
+TensorType RaggedRangeOpModel<int32>::GetType() {
+ return TensorType_INT32;
+}
+
+template <>
+TensorType RaggedRangeOpModel<float>::GetType() {
+ return TensorType_FLOAT32;
+}
+
+TEST(RaggedRangeOpTest, IntValues) {
+ RaggedRangeOpModel<int32> model({0, 5, 8, 5}, // Starts.
+ {8, 7, 8, 1}, // Limits.
+ {2, 1, 1, -1}); // Deltas.
+ model.Invoke();
+
+ EXPECT_THAT(model.GetSplits(),
+ testing::UnorderedElementsAreArray({0, 4, 6, 6, 10}));
+ EXPECT_THAT(model.GetValues(), testing::UnorderedElementsAreArray(
+ {0, 2, 4, 6, 5, 6, 5, 4, 3, 2}));
+}
+
+TEST(RaggedRangeOpTest, FloatValues) {
+ RaggedRangeOpModel<float> model({0, 5, 8, 5}, // Starts.
+ {8, 7, 8, 1}, // Limits.
+ {2, 1, 1, -1}); // Deltas.
+ model.Invoke();
+
+ EXPECT_THAT(model.GetSplits(),
+ testing::UnorderedElementsAreArray({0, 4, 6, 6, 10}));
+ EXPECT_THAT(model.GetValues(), testing::UnorderedElementsAreArray(
+ {0, 2, 4, 6, 5, 6, 5, 4, 3, 2}));
+}
+
+TEST(RaggedRangeOpTest, BroadcastDelta) {
+ RaggedRangeOpModel<int32> model({0, 5, 8}, // Starts.
+ {8, 7, 8}, // Limits.
+ {1}); // Deltas.
+ model.Invoke();
+
+ EXPECT_THAT(model.GetSplits(),
+ testing::UnorderedElementsAreArray({0, 8, 10, 10}));
+ EXPECT_THAT(model.GetValues(), testing::UnorderedElementsAreArray(
+ {0, 1, 2, 3, 4, 5, 6, 7, 5, 6}));
+}
+
+TEST(RaggedRangeOpTest, BroadcastStartDeltas) {
+ RaggedRangeOpModel<int32> model({0}, // Starts.
+ {10}, // Limits.
+ {2, 1}); // Deltas.
+ model.Invoke();
+
+ EXPECT_THAT(model.GetSplits(),
+ testing::UnorderedElementsAreArray({0, 5, 15}));
+ EXPECT_THAT(model.GetValues(),
+ testing::UnorderedElementsAreArray(
+ {0, 2, 4, 6, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
+}
+
+TEST(RaggedRangeOpTest, BadDeltas) {
+ RaggedRangeOpModel<int32> model({0, 5, 8, 5}, // Starts.
+ {8, 7, 7, 9}, // Limits.
+ {0, 1, 1, 1}); // Deltas.
+ EXPECT_EQ(model.InvokeUnchecked(), kTfLiteError);
+}
+
+TEST(RaggedRangeOpTest, ZeroRange) {
+ RaggedRangeOpModel<int32> model({0, 7}, // Starts.
+ {8, 5}, // Limits.
+ {1, 1}); // Deltas.
+ model.Invoke();
+ EXPECT_THAT(model.GetSplits(), testing::UnorderedElementsAreArray({0, 8, 8}));
+ EXPECT_THAT(model.GetValues(),
+ testing::UnorderedElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7}));
+}
+
+} // namespace
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
new file mode 100644
index 00000000..09ac76c7
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc
@@ -0,0 +1,690 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <memory>
+
+#include "flatbuffers/flexbuffers.h" // from @flatbuffers
+#include "tensorflow/core/util/ragged_to_dense_util_common.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace ragged {
+namespace ragged_tensor_to_tensor {
+namespace {
+
+constexpr int kShapeInput = 0;
+constexpr int kValuesInput = 1;
+constexpr int kDefaultValueInput = 2;
+constexpr int kFirstPartitionInputIndex = 3;
+
+constexpr int kOutputTensor = 0;
+
+constexpr char kRowPartitionTypesAttr[] = "row_partition_types";
+
+struct ConversionAttributes {
+ std::vector<tensorflow::RowPartitionType> partition_types;
+ int ragged_rank = 0;
+
+ tensorflow::RowPartitionType GetRowPartitionTypeByDimension(
+ int dimension) const {
+ if (partition_types.front() ==
+ tensorflow::RowPartitionType::FIRST_DIM_SIZE) {
+ return partition_types[dimension + 1];
+ } else {
+ return partition_types[dimension];
+ }
+ }
+};
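+
+// When the first partition is FIRST_DIM_SIZE, the row-partition tensors for
+// the ragged dimensions start one input later; GetRowPartitionTypeByDimension
+// above and GetRowPartitionTensor below both apply that offset.
+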
+template <typename INDEX_TYPE>
+int GetFirstDimensionSizeT(TfLiteContext* context,
+ const TfLiteTensor& first_partition_input,
+ const ConversionAttributes* attributes) {
+ const tensorflow::RowPartitionType first_partition_type =
+ attributes->partition_types.front();
+ switch (first_partition_type) {
+ case tensorflow::RowPartitionType::FIRST_DIM_SIZE:
+ return *GetTensorData<INDEX_TYPE>(&first_partition_input);
+ case tensorflow::RowPartitionType::VALUE_ROWIDS:
+ context->ReportError(context,
+ "Cannot handle VALUE_ROWIDS in first dimension.");
+ return -1;
+ case tensorflow::RowPartitionType::ROW_SPLITS: {
+ const auto shape = GetTensorShape(&first_partition_input);
+ return shape.Dims(0) - 1;
+ }
+
+ default:
+ context->ReportError(
+          context, "Cannot handle type %s",
+ RowPartitionTypeToString(first_partition_type).c_str());
+ return -1;
+ }
+}
+
+int GetFirstDimensionSize(TfLiteContext* context,
+ const TfLiteTensor& first_partition_input,
+ const ConversionAttributes* attributes) {
+ switch (first_partition_input.type) {
+ case kTfLiteInt32:
+ return GetFirstDimensionSizeT<int32_t>(context, first_partition_input,
+ attributes);
+ case kTfLiteInt64:
+ return GetFirstDimensionSizeT<int64_t>(context, first_partition_input,
+ attributes);
+ default:
+ context->ReportError(context,
+                           "Unsupported row partitioning tensor type");
+ return -1;
+ }
+}
+
+bool ValidateDefaultValueShape(TfLiteContext* context,
+ const RuntimeShape& default_value_shape,
+ const RuntimeShape& /*value_shape*/) {
+  // The TF implementation also checks the case where the shapes are not fully
+  // defined; that is not needed in TFLite.
+ // TODO(mgubin): Only scalar default value sizes are supported.
+ if (default_value_shape.FlatSize() != 1) {
+ context->ReportError(context, "Only scalar default value is supported");
+ return false;
+ }
+ return true;
+}
+
+RuntimeShape TensorShapeFromTensor(const TfLiteTensor& tensor) {
+ // TODO(mgubin): No checks, see
+ // third_party/tensorflow/core/kernels/list_kernels.cc
+ const RuntimeShape tensor_shape(tensor.dims->size, tensor.dims->data);
+ if (0 == tensor.dims->size) {
+ // If the input tensor is scalar then the shape is empty (also scalar).
+ return RuntimeShape{};
+ }
+ RuntimeShape result(tensor_shape.FlatSize());
+ switch (tensor.type) {
+ case kTfLiteInt32: {
+ for (int i = 0; i < tensor_shape.FlatSize(); ++i) {
+ result.SetDim(i, GetTensorData<int32_t>(&tensor)[i]);
+ }
+ } break;
+ case kTfLiteInt64: {
+ for (int i = 0; i < tensor_shape.FlatSize(); ++i) {
+ result.SetDim(i, GetTensorData<int64_t>(&tensor)[i]);
+ }
+ } break;
+ default: {
+ // Checked in Prepare.
+ }
+ }
+ return result;
+}
+
+const TfLiteTensor* GetRowPartitionTensor(
+ const ConversionAttributes& conversion_attributes, TfLiteContext* context,
+ TfLiteNode* node, int dimension) {
+ if (conversion_attributes.partition_types.front() ==
+ tensorflow::RowPartitionType::FIRST_DIM_SIZE) {
+ return &context->tensors[node->inputs->data[kFirstPartitionInputIndex + 1 +
+ dimension]];
+ } else {
+ return &context->tensors[node->inputs
+ ->data[kFirstPartitionInputIndex + dimension]];
+ }
+}
+
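+// Returns the maximal row width encoded by a VALUE_ROWIDS tensor, i.e. the
+// length of the longest run of equal row ids. For example,
+// value_rowids = [0 0 0 2 2 2 2 3 3] has runs of lengths 3, 4 and 2, so the
+// result is 4.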
+int GetMaxWidthValueRowID(const TfLiteTensor* tensor) {
+ const RuntimeShape tensor_shape(tensor->dims->size, tensor->dims->data);
+ const int index_length = tensor_shape.FlatSize();
+ if (index_length == 0) {
+ return 0;
+ }
+ auto value_rowids = [tensor](int index) {
+ switch (tensor->type) {
+ case kTfLiteInt32:
+ return static_cast<int>(tensor->data.i32[index]);
+ case kTfLiteInt64:
+ return static_cast<int>(tensor->data.i64[index]);
+ default:
+ // TODO(mgubin): Add error checks.
+ return 0;
+ }
+ };
+ int first_equal_index = 0;
+ int first_equal_index_value = value_rowids(0);
+ int max_width = 0;
+ for (int i = 0; i < index_length; ++i) {
+ const int value = value_rowids(i);
+ if (value != first_equal_index_value) {
+ first_equal_index_value = value;
+ max_width = std::max(i - first_equal_index, max_width);
+ first_equal_index = i;
+ }
+ }
+ return std::max(index_length - first_equal_index, max_width);
+}
+
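+// Returns the maximal row length encoded by a ROW_SPLITS tensor. For example,
+// row_splits = [0, 3, 3, 7, 9] describes rows of lengths 3, 0, 4 and 2, so
+// the result is 4.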
+int GetMaxWidthRowSplit(const TfLiteTensor* tensor) {
+ const RuntimeShape tensor_shape(tensor->dims->size, tensor->dims->data);
+ const int tensor_length = tensor_shape.FlatSize();
+ if (tensor_length == 0 || tensor_length == 1) {
+ return 0;
+ }
+ auto value_rowsplit = [tensor](int index) {
+ switch (tensor->type) {
+ case kTfLiteInt32:
+ return static_cast<int>(tensor->data.i32[index]);
+ case kTfLiteInt64:
+ return static_cast<int>(tensor->data.i64[index]);
+ default:
+ // TODO(mgubin): Add error checks.
+ return 0;
+ }
+ };
+ int max_width = 1;
+ int prev_split = value_rowsplit(0);
+ for (int i = 1; i < tensor_length; ++i) {
+ const int split = value_rowsplit(i);
+ max_width = std::max(max_width, split - prev_split);
+ prev_split = split;
+ }
+ return max_width;
+}
+
+int GetMaxWidth(const ConversionAttributes& conversion_attributes,
+ TfLiteContext* context, TfLiteNode* node, int dimension) {
+ const TfLiteTensor* tensor = GetRowPartitionTensor(
+ conversion_attributes, context, node, dimension - 1);
+ switch (conversion_attributes.GetRowPartitionTypeByDimension(dimension - 1)) {
+ case tensorflow::RowPartitionType::VALUE_ROWIDS:
+ return GetMaxWidthValueRowID(tensor);
+ case tensorflow::RowPartitionType::ROW_SPLITS:
+ return GetMaxWidthRowSplit(tensor);
+ default:
+ context->ReportError(context, "Cannot handle partition type");
+ return -1;
+ }
+}
+
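+// Combines the requested output shape, the ragged rank and the value shape
+// into the dense output shape. Unknown dimensions are left as -1 and filled
+// in later from the first dimension size and the maximal row widths; trailing
+// dense dimensions are copied from the value shape.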
+RuntimeShape CombineRaggedTensorToTensorShapes(
+ int ragged_rank, const RuntimeShape& output_shape,
+ const RuntimeShape& value_shape) {
+ // TODO(mgubin): No checks, see
+ // third_party/tensorflow/core/ops/ragged_to_dense_util.cc
+ RuntimeShape result(output_shape);
+ if (output_shape.DimensionsCount() == 0) {
+ const int output_shape_rank = ragged_rank + value_shape.DimensionsCount();
+ result.Resize(output_shape_rank);
+ for (int i = 0; i < output_shape_rank; ++i) {
+ result.SetDim(i, -1);
+ }
+ }
+ const int need_to_set =
+ output_shape.DimensionsCount() - value_shape.DimensionsCount();
+ for (int i = 1; i < value_shape.DimensionsCount(); ++i) {
+ result.SetDim(need_to_set + i, value_shape.Dims(i));
+ }
+ return result;
+}
+
+RuntimeShape CalculateOutputSize(
+ const ConversionAttributes& conversion_attributes, TfLiteContext* context,
+ TfLiteNode* node, int first_dimension, int ragged_rank,
+ const TfLiteTensor& values, const TfLiteTensor& default_value,
+ const TfLiteTensor& output_shape) {
+ RuntimeShape values_shape(values.dims->size, values.dims->data);
+ RuntimeShape default_value_shape(default_value.dims->size,
+ default_value.dims->data);
+
+ if (!ValidateDefaultValueShape(context, default_value_shape, values_shape)) {
+ return {};
+ }
+ RuntimeShape output_shape_shape = TensorShapeFromTensor(output_shape);
+
+ RuntimeShape result_shape = CombineRaggedTensorToTensorShapes(
+ ragged_rank, output_shape_shape, values_shape);
+ if (result_shape.Dims(0) < 0) {
+ result_shape.SetDim(0, first_dimension);
+ }
+ for (int i = 1; i <= ragged_rank; ++i) {
+ if (result_shape.Dims(i) < 0) {
+ result_shape.SetDim(i,
+ GetMaxWidth(conversion_attributes, context, node, i));
+ }
+ }
+ return result_shape;
+}
+
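+// Converts a RuntimeShape into a freshly allocated TfLiteIntArray; ownership
+// is transferred to ResizeTensor below.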
+TfLiteIntArray* IntArrayFromShape(const RuntimeShape& shape) {
+ TfLiteIntArray* result = TfLiteIntArrayCreate(shape.DimensionsCount());
+ for (int i = 0; i < shape.DimensionsCount(); ++i) {
+ result->data[i] = shape.Dims(i);
+ }
+ return result;
+}
+
+/**
+ * The output_index represents the index in the output tensor
+ * where the first element of a particular dimension would be written.
+ * If it is -1, it indicates that the index is out of scope.
+ * For example, given first_dimension = 10, first_dimension_output = 6,
+ * and output_index_multiplier = 100:
+ * result = [0 100 200 300 400 500 -1 -1 -1 -1]
+ * If first_dimension_output = 11 instead, then:
+ * result = [0 100 200 300 400 500 600 700 800 900]
+ */
+void CalculateFirstParentOutputIndex(int first_dimension,
+ int output_index_multiplier,
+ int first_dimension_output,
+ std::vector<int>* result) {
+ const int min_dimension = std::min(first_dimension, first_dimension_output);
+ result->reserve(first_dimension);
+ int current_output_index = 0;
+ for (int i = 0; i < min_dimension;
+ ++i, current_output_index += output_index_multiplier) {
+ result->push_back(current_output_index);
+ }
+ for (int i = min_dimension; i < first_dimension; ++i) {
+ result->push_back(-1);
+ }
+}
+// Calculate the output index of the first element of a list.
+// The parent_output_index is the same computation for the previous list.
+// -1 indicates an element or list that is out of range.
+// The output_index_multiplier is the number of output indices one moves
+// forward for each column.
+// E.g., given:
+//   value_rowids: [0 1 2 2 2 3 5 5 6]
+//   parent_output_index: [1000 1100 2000 2100 -1 -1 3000]
+//   output_index_multiplier: 10
+//   output_size: 2
+// You get:
+//   result = [1000 1100 2000 2010 -1 2100 -1 -1 3000]
+//   result[0] = parent_output_index[value_rowids[0]]
+//   result[1] = parent_output_index[value_rowids[1]]
+//   result[2] = parent_output_index[value_rowids[2]]
+//   result[3] = parent_output_index[value_rowids[2]] + 10
+//   result[4] = -1 because it is the third element in a row of size 2.
+//   result[5] = parent_output_index[value_rowids[5]]
+//   result[6] = -1 because parent_output_index[value_rowids[6]] == -1
+//   result[7] = -1 because parent_output_index[value_rowids[7]] == -1
+//   result[8] = parent_output_index[value_rowids[8]]
+void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids,
+ const std::vector<int>& parent_output_index,
+ int output_index_multiplier,
+ int output_size, std::vector<int>* result) {
+ const RuntimeShape tensor_shape(value_rowids.dims->size,
+ value_rowids.dims->data);
+ const int index_size = tensor_shape.FlatSize();
+ result->reserve(index_size);
+ if (index_size == 0) {
+ return;
+ }
+
+ auto value_rowids_val = [value_rowids](int index) {
+ switch (value_rowids.type) {
+ case kTfLiteInt32:
+ return static_cast<int>(value_rowids.data.i32[index]);
+ case kTfLiteInt64:
+ return static_cast<int>(value_rowids.data.i64[index]);
+ default:
+ // TODO(mgubin): Add error checks.
+ return 0;
+ }
+ };
+ int current_output_column = 0;
+ int current_value_rowid = value_rowids_val(0);
+ // DCHECK_LT(current_value_rowid, parent_output_index.size());
+ int current_output_index = parent_output_index[current_value_rowid];
+ result->push_back(current_output_index);
+ for (int i = 1; i < index_size; ++i) {
+ int next_value_rowid = value_rowids_val(i);
+ if (next_value_rowid == current_value_rowid) {
+ if (current_output_index >= 0) {
+ ++current_output_column;
+ if (current_output_column < output_size) {
+ current_output_index += output_index_multiplier;
+ } else {
+ current_output_index = -1;
+ }
+ }
+ } else {
+ current_output_column = 0;
+ current_value_rowid = next_value_rowid;
+ // DCHECK_LT(next_value_rowid, parent_output_index.size());
+ current_output_index = parent_output_index[next_value_rowid];
+ }
+ result->push_back(current_output_index);
+ }
+ // DCHECK_EQ(result->size(), value_rowids.size());
+}
+
+void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split,
+ const std::vector<int>& parent_output_index,
+ int output_index_multiplier, int output_size,
+ std::vector<int>* result) {
+ const RuntimeShape row_split_shape(row_split.dims->size,
+ row_split.dims->data);
+ const int row_split_size = row_split_shape.FlatSize();
+ auto row_split_val = [row_split](int index) {
+ switch (row_split.type) {
+ case kTfLiteInt32:
+ return static_cast<int>(row_split.data.i32[index]);
+ case kTfLiteInt64:
+ return static_cast<int>(row_split.data.i64[index]);
+ default:
+ // TODO(mgubin): Add error checks.
+ return 0;
+ }
+ };
+ if (row_split_size > 0) {
+ result->reserve(row_split_val(row_split_size - 1));
+ }
+ for (int i = 0; i < row_split_size - 1; ++i) {
+ const int row_length = row_split_val(i + 1) - row_split_val(i);
+ int real_length = std::min(output_size, row_length);
+ int parent_output_index_current = parent_output_index[i];
+
+ if (parent_output_index_current == -1) {
+ real_length = 0;
+ }
+ for (int j = 0; j < real_length; ++j) {
+ result->push_back(parent_output_index_current);
+ parent_output_index_current += output_index_multiplier;
+ }
+ for (int j = 0; j < row_length - real_length; ++j) {
+ result->push_back(-1);
+ }
+ }
+ // if (row_split_size > 0) {
+ // DCHECK_EQ(result->size(), row_split(row_split_size - 1));
+ //}
+}
+
+TfLiteStatus CalculateOutputIndex(
+ const ConversionAttributes& conversion_attributes, TfLiteContext* context,
+ TfLiteNode* node, int dimension,
+ const std::vector<int>& parent_output_index, int output_index_multiplier,
+ int output_size, std::vector<int>* result) {
+ const TfLiteTensor* row_partition_tensor =
+ GetRowPartitionTensor(conversion_attributes, context, node, dimension);
+ auto partition_type =
+ conversion_attributes.GetRowPartitionTypeByDimension(dimension);
+ switch (partition_type) {
+ case tensorflow::RowPartitionType::VALUE_ROWIDS:
+ CalculateOutputIndexValueRowID(*row_partition_tensor, parent_output_index,
+ output_index_multiplier, output_size,
+ result);
+ return kTfLiteOk;
+ case tensorflow::RowPartitionType::ROW_SPLITS:
+ CalculateOutputIndexRowSplit(*row_partition_tensor, parent_output_index,
+ output_index_multiplier, output_size,
+ result);
+ return kTfLiteOk;
+ default:
+ context->ReportError(context, "Unsupported partition type");
+ return kTfLiteError;
+ }
+}
+
+template <typename VALUE_TYPE>
+void SetOutputT(TfLiteContext* context, int ragged_rank,
+ const std::vector<int>& output_index,
+ const TfLiteTensor& values_tensor,
+ const TfLiteTensor& default_value_tensor,
+ TfLiteTensor* output_tensor) {
+ const VALUE_TYPE* values_base = GetTensorData<VALUE_TYPE>(&values_tensor);
+ VALUE_TYPE* output_base = GetTensorData<VALUE_TYPE>(output_tensor);
+ const VALUE_TYPE* default_value =
+ GetTensorData<VALUE_TYPE>(&default_value_tensor);
+
+ RuntimeShape output_shape = GetTensorShape(output_tensor);
+ RuntimeShape element_shape =
+ RuntimeShape(output_shape.DimensionsCount() - ragged_rank - 1,
+ output_shape.DimsData() + ragged_rank + 1);
+
+ // element_shape.RemoveDimRange(0, ragged_rank + 1);
+ const int value_element_size = element_shape.FlatSize();
+ size_t output_index_size = output_index.size();
+
+ // Loop through the output_index vector, finding contiguous regions that
+ // should be copied. Once we find the end of a contiguous region, copy it
+ // and add any necessary padding (with default_value).
+ int src_start = 0; // Start of contiguous region (in values)
+ int dst_start = 0; // Destination for contiguous region (in output)
+ int dst_end = 0; // Destination for contiguous region (in output)
+ for (int src_i = 0; src_i <= output_index_size; ++src_i) {
+ // dst_i is the destination where the value at src_i should be copied.
+ int dst_i = src_i < output_index_size ? output_index[src_i] : -1;
+
+    // If we're still in a contiguous region, then update dst_end and go to
+    // the next src_i.
+ if (dst_i == dst_end) {
+ ++dst_end;
+ continue;
+ }
+
+    // We found the end of a contiguous region: either there is a gap
+    // (dst_i > dst_end), the source value shouldn't be copied because it is
+    // out-of-bounds (dst_i == -1), or we reached the end of the tensor
+    // (dst_i is also -1 in that case).
+ if (dst_start < dst_end) {
+ // Copy the contiguous region.
+ const VALUE_TYPE* src = values_base + src_start * value_element_size;
+ VALUE_TYPE* dst = output_base + dst_start * value_element_size;
+ int nvals = (dst_end - dst_start) * value_element_size;
+ std::copy(src, src + nvals, dst);
+ // copy_array<VALUE_TYPE, int>(dst, src, nvals);
+ }
+
+ // Add any necessary padding (w/ default_value).
+ if (src_i >= output_index_size) {
+ // We reached the end of values: pad to the end of output.
+ const int output_size = output_shape.FlatSize();
+ dst_i = output_size / value_element_size;
+ }
+ if (dst_i > dst_end) {
+ std::fill(output_base + dst_end * value_element_size,
+ output_base + dst_i * value_element_size, *default_value);
+ dst_end = dst_i;
+ }
+
+ // Update indices.
+ if (dst_i < 0) {
+ // src_i should be skipped -- leave it out of the contiguous region.
+ src_start = src_i + 1;
+ dst_start = dst_end;
+ } else {
+ // src_i should be copied -- include it in the contiguous region.
+ src_start = src_i;
+ dst_start = dst_end;
+ dst_end = dst_start + 1;
+ }
+ }
+}
+
+void SetOutput(TfLiteContext* context, int ragged_rank,
+ const std::vector<int>& output_index,
+ const TfLiteTensor& values_tensor,
+ const TfLiteTensor& default_value_tensor,
+ TfLiteTensor* output_tensor) {
+ switch (output_tensor->type) {
+ case kTfLiteInt32:
+ SetOutputT<int32_t>(context, ragged_rank, output_index, values_tensor,
+ default_value_tensor, output_tensor);
+ break;
+ case kTfLiteInt64:
+ SetOutputT<int64_t>(context, ragged_rank, output_index, values_tensor,
+ default_value_tensor, output_tensor);
+ break;
+ case kTfLiteFloat32:
+ SetOutputT<float>(context, ragged_rank, output_index, values_tensor,
+ default_value_tensor, output_tensor);
+ break;
+ default:
+ context->ReportError(context, "Not supported values type");
+ }
+}
+
+} // namespace
+
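+// Parses the FlexBuffer custom options: reads the row_partition_types
+// attribute and derives the ragged rank. Returns nullptr if the attribute
+// cannot be parsed.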
+void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
+ auto attributes = std::make_unique<ConversionAttributes>();
+
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+
+ const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+  // TODO(mgubin): Converting the flat buffer into a vector of strings is
+  // simple but not very efficient; a cleaner way is needed.
+ const flexbuffers::TypedVector row_partition_types_attr =
+ m[kRowPartitionTypesAttr].AsTypedVector();
+ std::vector<std::string> row_partition_types_attr_strings;
+ row_partition_types_attr_strings.reserve(row_partition_types_attr.size());
+ for (int i = 0; i < row_partition_types_attr.size(); ++i) {
+ row_partition_types_attr_strings.emplace_back(
+ row_partition_types_attr[i].AsString().str());
+ }
+ attributes->partition_types =
+ tensorflow::GetRowPartitionTypesHelper(row_partition_types_attr_strings);
+ if (attributes->partition_types.size() !=
+ row_partition_types_attr_strings.size()) {
+ context->ReportError(context, "Can't parse partition type attribute");
+ return nullptr;
+ }
+ attributes->ragged_rank =
+ tensorflow::GetRaggedRank(attributes->partition_types);
+ return attributes.release();
+}
+
+void Free(TfLiteContext* /*context*/, void* buffer) {
+ ConversionAttributes* attributes =
+ reinterpret_cast<ConversionAttributes*>(buffer);
+ delete attributes;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const ConversionAttributes* attributes =
+ reinterpret_cast<ConversionAttributes*>(node->user_data);
+ if (attributes == nullptr) {
+ // Parsing attributes failed, can't prepare.
+ context->ReportError(context, "Attributes are not initialized");
+ return kTfLiteError;
+ }
+  // The output tensor needs to be set to dynamic because its size can vary.
+ TfLiteTensor& output_tensor =
+ context->tensors[node->outputs->data[kOutputTensor]];
+ SetTensorToDynamic(&output_tensor);
+
+  // Check that the input shape tensor is int32 or int64.
+ TfLiteTensor& input_shape = context->tensors[node->inputs->data[kShapeInput]];
+ if (input_shape.type != kTfLiteInt32 && input_shape.type != kTfLiteInt64) {
+    context->ReportError(context,
+                         "Input shape tensor must be int32 or int64");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const ConversionAttributes* attributes =
+ reinterpret_cast<ConversionAttributes*>(node->user_data);
+ TfLiteTensor& input_shape = context->tensors[node->inputs->data[kShapeInput]];
+ TfLiteTensor& input_values =
+ context->tensors[node->inputs->data[kValuesInput]];
+ TfLiteTensor& default_value =
+ context->tensors[node->inputs->data[kDefaultValueInput]];
+  // TODO(mgubin): Only scalar default values are supported.
+ if (RuntimeShape(default_value.dims->size, default_value.dims->data)
+ .FlatSize() != 1) {
+ context->ReportError(context, "Only scallar default value is supported");
+ return kTfLiteError;
+ }
+ TfLiteTensor& first_partition_input =
+ context->tensors[node->inputs->data[kFirstPartitionInputIndex]];
+
+ // Calculate dimensions.
+ const int first_dimension =
+ GetFirstDimensionSize(context, first_partition_input, attributes);
+ if (first_dimension < 0) {
+ return kTfLiteError;
+ }
+ RuntimeShape output_shape = CalculateOutputSize(
+ *attributes, context, node, first_dimension, attributes->ragged_rank,
+ input_values, default_value, input_shape);
+ if (output_shape.DimensionsCount() == 0) {
+ return kTfLiteError;
+ }
+
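+  // multiplier[i] is the number of flat output positions one moves forward
+  // per step along dimension i; e.g. for an output shape of [3, 5, 2] with
+  // ragged_rank 2, multiplier is [10, 2, 1].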
+ std::vector<int> multiplier;
+ multiplier.resize(attributes->ragged_rank + 1);
+ multiplier.back() = 1;
+ for (int i = multiplier.size() - 2; i >= 0; --i) {
+ multiplier[i] = multiplier[i + 1] * output_shape.Dims(i + 1);
+ }
+
+ // Allocate output tensor.
+ TfLiteTensor& output_tensor =
+ context->tensors[node->outputs->data[kOutputTensor]];
+
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, &output_tensor,
+ IntArrayFromShape(output_shape)));
+
+ // Copy data.
+ const int full_size = multiplier.front() * output_shape.Dims(0);
+ if (full_size > 0) {
+ std::vector<int> output_index, new_output_index;
+ int nvals = input_values.dims->data[0];
+ output_index.reserve(nvals);
+ new_output_index.reserve(nvals);
+
+ CalculateFirstParentOutputIndex(first_dimension, multiplier[0],
+ output_shape.Dims(0), &output_index);
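+    // Refine the flat output index of every value, one ragged dimension at a
+    // time.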
+ for (int i = 1; i <= attributes->ragged_rank; ++i) {
+ TF_LITE_ENSURE_OK(
+ context, CalculateOutputIndex(
+ *attributes, context, node, i - 1, output_index,
+ multiplier[i], output_shape.Dims(i), &new_output_index));
+ output_index.swap(new_output_index);
+ new_output_index.clear();
+ }
+
+ SetOutput(context, attributes->ragged_rank, output_index, input_values,
+ default_value, &output_tensor);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace ragged_tensor_to_tensor
+} // namespace ragged
+
+TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR() {
+ static TfLiteRegistration r = {ragged::ragged_tensor_to_tensor::Initialize,
+ ragged::ragged_tensor_to_tensor::Free,
+ ragged::ragged_tensor_to_tensor::Prepare,
+ ragged::ragged_tensor_to_tensor::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
new file mode 100644
index 00000000..b1cde57c
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc
@@ -0,0 +1,283 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
+} // namespace custom
+} // namespace ops
+
+namespace {
+
+class RaggedTensorToTensorOpModel : public SingleOpModel {
+ public:
+ RaggedTensorToTensorOpModel(int output_shape_dims,
+ std::initializer_list<int> values_shape,
+ std::initializer_list<std::initializer_list<int>>
+ partition_tensors_shapes,
+ std::vector<std::string> partition_types,
+ TensorType value_type = TensorType_FLOAT32,
+ TensorType index_type = TensorType_INT32) {
+ // A structure to collect shapes for the input.
+ std::vector<std::vector<int>> shapes;
+ input_shape_ = AddInput(index_type);
+ shapes.push_back({output_shape_dims});
+ input_values_ = AddInput(value_type);
+ shapes.emplace_back(values_shape);
+ input_default_values_ = AddInput(value_type);
+ shapes.push_back({1});
+ for (const auto& p : partition_tensors_shapes) {
+ partition_tensors_.push_back(AddInput(TensorType_INT32));
+ shapes.emplace_back(p);
+ }
+ output_ = AddOutput(value_type);
+
+ flexbuffers::Builder fbb;
+ size_t start = fbb.StartMap();
+ {
+ size_t start = fbb.StartVector("row_partition_types");
+ for (const auto& s : partition_types) {
+ fbb.String(s);
+ }
+ fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
+ }
+ fbb.Int("num_row_partition_tensors", partition_types.size());
+ fbb.EndMap(start);
+ fbb.Finish();
+ SetCustomOp("RaggedTensorToTensor", fbb.GetBuffer(),
+ ops::custom::Register_RAGGED_TENSOR_TO_TENSOR);
+ BuildInterpreter(shapes);
+ }
+
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); }
+ std::vector<int32> GetOutputInt() { return ExtractVector<int32>(output_); }
+
+ void InvokeFloat(const std::vector<int>& shape,
+ const std::vector<float>& values, float default_value,
+ const std::vector<std::vector<int>>& partition_values) {
+ PopulateTensor(input_shape_, shape);
+ PopulateTensor(input_values_, values);
+ PopulateTensor(input_default_values_, {default_value});
+ for (int i = 0; i < partition_values.size(); ++i) {
+ PopulateTensor(partition_tensors_[i], partition_values[i]);
+ }
+ SingleOpModel::Invoke();
+ }
+ void InvokeInt(const std::vector<int>& shape,
+ const std::vector<int32>& values, int32 default_value,
+ const std::vector<std::vector<int>>& partition_values) {
+ PopulateTensor(input_shape_, shape);
+ PopulateTensor(input_values_, values);
+ PopulateTensor(input_default_values_, {default_value});
+ for (int i = 0; i < partition_values.size(); ++i) {
+ PopulateTensor(partition_tensors_[i], partition_values[i]);
+ }
+ SingleOpModel::Invoke();
+ }
+
+ private:
+ int input_shape_;
+ int input_values_;
+ int input_default_values_;
+ std::vector<int> partition_tensors_;
+ int output_;
+};
+
+TEST(RaggedTensorToTensorTest, RaggedTensorToTensor) {
+ // indices = [2, 1, 0, 3]
+ // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+ // params.shape = [4, None]
+ RaggedTensorToTensorOpModel model(
+ 2, // output_shape_dims
+ {9}, // values_shape
+ {{1}, {9}}, // partition_tensors_shapes
+ std::vector<std::string>({"FIRST_DIM_SIZE", "VALUE_ROWIDS"}));
+ model.InvokeFloat({4, 4}, // shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values
+ 1.5, // default_value
+ std::vector<std::vector<int>>(
+ {std::vector<int>({4}),
+ std::vector<int>({0, 0, 0, 2, 2, 2, 2, 3, 3})}));
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({4, 4}));
+ EXPECT_THAT(model.GetOutputFloat(),
+ testing::ElementsAreArray({.1, .2, .3, 1.5, 1.5, 1.5, 1.5, 1.5,
+ .4, .5, .6, .7, .8, .9, 1.5, 1.5}));
+}
+
+TEST(RaggedTensorToTensorTest, RaggedTensorToTensorRowSplits) {
+ // indices = [2, 1, 0, 3]
+ // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+ RaggedTensorToTensorOpModel model(2, // output_shape_dims
+ {9}, // values_shape
+ {{5}}, // partition_tensors_shapes
+ std::vector<std::string>({"ROW_SPLITS"}));
+ model.InvokeFloat(
+ {4, 4}, // shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values
+ 1.5, // default_value
+ std::vector<std::vector<int>>({std::vector<int>({0, 3, 3, 7, 9})}));
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({4, 4}));
+ EXPECT_THAT(model.GetOutputFloat(),
+ testing::ElementsAreArray({.1, .2, .3, 1.5, 1.5, 1.5, 1.5, 1.5,
+ .4, .5, .6, .7, .8, .9, 1.5, 1.5}));
+}
+
+TEST(RaggedTensorToTensorTest, RaggedTensorToTensor_3DParams) {
+ // params = [
+ // [[]],
+ // [[.1, .2], [.3]],
+ // [],
+ // [[.4, .5], [.6, .7, .8]],
+ // [[.9]]
+ // ]
+ RaggedTensorToTensorOpModel model(
+ 3, // output_shape_dims
+ {9}, // values_shape
+ {{1}, {6}, {9}}, // partition_tensors_shapes
+ std::vector<std::string>(
+ {"FIRST_DIM_SIZE", "VALUE_ROWIDS", "VALUE_ROWIDS"}));
+ model.InvokeFloat(
+ {5, 2, 3}, // shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values
+ 1.5, // default_value
+ std::vector<std::vector<int>>(
+ {std::vector<int>({5}), std::vector<int>({0, 1, 1, 3, 3, 4}),
+ std::vector<int>({1, 1, 2, 3, 3, 4, 4, 4, 5})}));
+
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({5, 2, 3}));
+ EXPECT_THAT(model.GetOutputFloat(),
+ testing::ElementsAreArray({1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .1, .2,
+ 1.5, .3, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
+ 1.5, 1.5, .4, .5, 1.5, .6, .7, .8,
+ .9, 1.5, 1.5, 1.5, 1.5, 1.5}));
+}
+
+TEST(RaggedTensorToTensorTest, RaggedTensorToTensor_3DParamsRowSplits) {
+ // params = [
+ // [[]],
+ // [[.1, .2], [.3]],
+ // [],
+ // [[.4, .5], [.6, .7, .8]],
+ // [[.9]]
+ // ]
+ RaggedTensorToTensorOpModel model(
+ 3, // output_shape_dims
+ {9}, // values_shape
+ {{6}, {7}}, // partition_tensors_shapes
+ std::vector<std::string>({"ROW_SPLITS", "ROW_SPLITS"}));
+ model.InvokeFloat(
+ {5, 2, 3}, // shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values
+ 1.5, // default_value
+ std::vector<std::vector<int>>({std::vector<int>({0, 1, 3, 3, 5, 6}),
+ std::vector<int>({0, 0, 2, 3, 5, 8, 9})}));
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({5, 2, 3}));
+ EXPECT_THAT(model.GetOutputFloat(),
+ testing::ElementsAreArray({1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .1, .2,
+ 1.5, .3, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
+ 1.5, 1.5, .4, .5, 1.5, .6, .7, .8,
+ .9, 1.5, 1.5, 1.5, 1.5, 1.5}));
+}
+
+TEST(RaggedTensorToTensorTest, RaggedTensorToTensor_3DParamsRowSplits2) {
+ // params = [
+ // [[0, 1, 2], []],
+ // [],
+ // [[3]]
+ // ]
+
+ RaggedTensorToTensorOpModel model(
+ 3, // output_shape_dims
+ {4}, // values_shape
+ {{4}, {4}}, // partition_tensors_shapes
+ std::vector<std::string>({"ROW_SPLITS", "ROW_SPLITS"}), TensorType_INT32);
+ model.InvokeInt(
+ {3, 2, 3}, // shape
+ {0, 1, 2, 3}, // values
+ 5, // default_value
+ std::vector<std::vector<int>>(
+ {std::vector<int>({0, 2, 2, 3}), std::vector<int>({0, 3, 3, 4})}));
+
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 2, 3}));
+
+ EXPECT_THAT(model.GetOutputInt(),
+ testing::ElementsAreArray(
+ {0, 1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 5, 5, 5, 5, 5}));
+}
+
+TEST(RaggedTensorToTensorTest, RaggedTensorToTensorContractExpanded) {
+ // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+ RaggedTensorToTensorOpModel model(
+ 2, // output_shape_dims
+ {9}, // values_shape
+ {{1}, {9}}, // partition_tensors_shapes
+ std::vector<std::string>({"FIRST_DIM_SIZE", "VALUE_ROWIDS"}));
+ model.InvokeFloat({3, 5}, // shape
+ {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values
+ 1.5, // default_value
+ std::vector<std::vector<int>>(
+ {std::vector<int>({4}),
+ std::vector<int>({0, 0, 0, 2, 2, 2, 2, 3, 3})}));
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 5}));
+
+ EXPECT_THAT(model.GetOutputFloat(),
+ testing::ElementsAreArray({.1, .2, .3, 1.5, 1.5, //
+ 1.5, 1.5, 1.5, 1.5, 1.5, //
+ .4, .5, .6, .7, 1.5}));
+}
+
+// Adds a dense dimension.
+TEST(RaggedTensorToTensorTest, RaggedTensorToTensorContractExpandedDense) {
+ // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]]
+ RaggedTensorToTensorOpModel model(
+ 3, // output_shape_dims
+ {9, 2}, // values_shape
+ {{1}, {9}}, // partition_tensors_shapes
+ std::vector<std::string>({"FIRST_DIM_SIZE", "VALUE_ROWIDS"}));
+
+ model.InvokeFloat({3, 5, 2}, // shape
+ {.1, 1.1, .2, 1.2, .3, 1.3, .4, 1.4, .5, 1.5, .6, 1.6, .7,
+ 1.7, .8, 1.8, .9, 1.9}, // values
+ 1.5, // default_value
+ std::vector<std::vector<int>>(
+ {std::vector<int>({4}),
+ std::vector<int>({0, 0, 0, 2, 2, 2, 2, 3, 3})}));
+
+ EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 5, 2}));
+ EXPECT_THAT(model.GetOutputFloat(),
+ testing::ElementsAreArray(
+ {.1, 1.1, .2, 1.2, .3, 1.3, 1.5, 1.5, 1.5, 1.5, //
+ 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, //
+ .4, 1.4, .5, 1.5, .6, 1.6, .7, 1.7, 1.5, 1.5}));
+}
+} // namespace
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD
new file mode 100644
index 00000000..e8df50f3
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD
@@ -0,0 +1,389 @@
+# Memory-mappable, WASM-compilable implementation of the encoder.
+#
+
+load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+load(":native.bzl", "micore_tf_copts", "micore_tf_deps")
+
+package(
+ default_visibility = [
+ "//tensorflow_lite_support:users",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+filegroup(
+ name = "testdata",
+ srcs = glob([
+ "testdata/**",
+ ]),
+)
+
+filegroup(
+ name = "config_fbs",
+ srcs = ["config.fbs"],
+)
+
+flatbuffer_cc_library(
+ name = "config",
+ srcs = [
+ "config.fbs",
+ ],
+)
+
+flatbuffer_cc_library(
+ name = "encoder_config",
+ srcs = [
+ "encoder_config.fbs",
+ ],
+ includes = [":config_fbs"],
+)
+
+flatbuffer_cc_library(
+ name = "decoder_config",
+ srcs = [
+ "decoder_config.fbs",
+ ],
+ includes = [":config_fbs"],
+)
+
+cc_library(
+ name = "utils",
+ srcs = [
+ ],
+ hdrs = [
+ "utils.h",
+ ],
+)
+
+cc_library(
+ name = "double_array_trie",
+ srcs = [
+ ],
+ hdrs = [
+ "double_array_trie.h",
+ ],
+ deps = [
+ ":config",
+ ":utils",
+ ],
+)
+
+cc_library(
+ name = "double_array_trie_builder",
+ srcs = [
+ "double_array_trie_builder.cc",
+ ],
+ hdrs = [
+ "double_array_trie_builder.h",
+ ],
+ deps = [
+ ":config",
+ ":utils",
+ "@darts_clone",
+ ],
+)
+
+cc_test(
+ name = "double_array_trie_test",
+ srcs = [
+ "double_array_trie_test.cc",
+ ],
+ deps = [
+ ":double_array_trie",
+ ":double_array_trie_builder",
+ ":encoder_config",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "sentencepiece_constants",
+ srcs = [],
+ hdrs = ["sentencepiece_constants.h"],
+)
+
+cc_library(
+ name = "model_converter",
+ srcs = [
+ "model_converter.cc",
+ ],
+ hdrs = [
+ "model_converter.h",
+ ],
+ deps = [
+ ":config",
+ ":decoder_config",
+ ":double_array_trie_builder",
+ ":encoder_config",
+ ":sentencepiece_constants",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
+ ],
+)
+
+cc_library(
+ name = "optimized_encoder",
+ srcs = [
+ "optimized_encoder.cc",
+ ],
+ hdrs = [
+ "optimized_encoder.h",
+ ],
+ deps = [
+ ":config",
+ ":double_array_trie",
+ ":encoder_config",
+ ],
+)
+
+cc_library(
+ name = "optimized_decoder",
+ srcs = [
+ "optimized_decoder.cc",
+ ],
+ hdrs = [
+ "optimized_decoder.h",
+ ],
+ deps = [
+ "config",
+ ":decoder_config",
+ ":double_array_trie",
+ ],
+)
+
+cc_library(
+ name = "sentencepiece_tokenizer_h",
+ hdrs = [
+ "sentencepiece_tokenizer.h",
+ ],
+)
+
+cc_library(
+ name = "sentencepiece_detokenizer_h",
+ hdrs = [
+ "sentencepiece_detokenizer.h",
+ ],
+)
+
+cc_library(
+ name = "sentencepiece_tokenizer_op",
+ srcs = ["sentencepiece_tokenizer_op.cc"],
+ copts = micore_tf_copts(),
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ ":sentencepiece_tokenizer_h",
+ ":optimized_encoder",
+ ] + micore_tf_deps(),
+ alwayslink = 1,
+)
+
+cc_binary(
+ name = "sentencepiece_tokenizer_op.so",
+ srcs = [
+ "sentencepiece_tokenizer_op.cc",
+ ],
+ copts = micore_tf_copts(),
+ linkshared = 1,
+ deps = [
+ ":sentencepiece_tokenizer_h",
+ ":optimized_encoder",
+ ] + micore_tf_deps(),
+)
+
+cc_library(
+ name = "sentencepiece_detokenizer_op",
+ srcs = ["sentencepiece_detokenizer_op.cc"],
+ copts = micore_tf_copts(),
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ ":sentencepiece_detokenizer_h",
+ ":optimized_decoder",
+ ] + micore_tf_deps(),
+ alwayslink = 1,
+)
+
+cc_binary(
+ name = "sentencepiece_detokenizer_op.so",
+ srcs = [
+ "sentencepiece_detokenizer_op.cc",
+ ],
+ copts = micore_tf_copts(),
+ linkshared = 1,
+ deps = [
+ ":sentencepiece_detokenizer_h",
+ ":optimized_decoder",
+ ] + micore_tf_deps(),
+)
+
+cc_library(
+ name = "sentencepiece_tokenizer_tflite",
+ srcs = ["sentencepiece_tokenizer_tflite.cc"],
+ visibility = [
+ "//visibility:public",
+ ],
+ deps =
+ [
+ ":optimized_encoder",
+ ":sentencepiece_tokenizer_h",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite:string_util",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+ ],
+)
+
+cc_library(
+ name = "sentencepiece_detokenizer_tflite",
+ srcs = ["sentencepiece_detokenizer_tflite.cc"],
+ visibility = [
+ "//visibility:public",
+ ],
+ deps =
+ [
+ ":optimized_decoder",
+ ":sentencepiece_detokenizer_h",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite:string_util",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+ ],
+)
+
+cc_test(
+ name = "optimized_encoder_test",
+ srcs = [
+ "optimized_encoder_test.cc",
+ ],
+ data = [
+ ":testdata",
+ ],
+ deps = [
+ ":double_array_trie_builder",
+ ":encoder_config",
+ ":model_converter",
+ ":optimized_encoder",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_sentencepiece//src:sentencepiece_cc_proto",
+ "@com_google_sentencepiece//src:sentencepiece_processor",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "optimized_decoder_test",
+ srcs = [
+ "optimized_decoder_test.cc",
+ ],
+ data = [
+ ":testdata",
+ ],
+ deps = [
+ ":decoder_config",
+ ":double_array_trie_builder",
+ ":model_converter",
+ ":optimized_decoder",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_sentencepiece//src:sentencepiece_cc_proto",
+ "@com_google_sentencepiece//src:sentencepiece_processor",
+ "@org_tensorflow//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "py_tflite_registerer",
+ srcs = ["py_tflite_registerer.cc"],
+ hdrs = ["py_tflite_registerer.h"],
+ deps = [
+ ":sentencepiece_detokenizer_tflite",
+ ":sentencepiece_tokenizer_tflite",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ ],
+ alwayslink = 1,
+)
+
+config_setting(
+ name = "armeabi_v7a_and_fastbuild",
+ values = {
+ "cpu": "armeabi-v7a",
+ "compilation_mode": "fastbuild",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "armeabi_v7a_and_dbg",
+ values = {
+ "cpu": "armeabi-v7a",
+ "compilation_mode": "dbg",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "android",
+ values = {"crosstool_top": "//external:android/crosstool"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "macos_i386",
+ values = {
+ "apple_platform_type": "macos",
+ "cpu": "darwin",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "macos_x86_64",
+ values = {
+ "apple_platform_type": "macos",
+ "cpu": "darwin_x86_64",
+ },
+ visibility = ["//visibility:public"],
+)
+
+alias(
+ name = "macos",
+ actual = select({
+ ":macos_i386": ":macos_i386",
+ ":macos_x86_64": ":macos_x86_64",
+ "//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
+ }),
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "ios",
+ values = {
+ "crosstool_top": "@bazel_tools//tools/cpp:toolchain",
+ "apple_platform_type": "ios",
+ },
+ visibility = ["//visibility:public"],
+)
+
+alias(
+ name = "apple",
+ actual = select({
+ ":macos": ":macos",
+ ":ios": ":ios",
+ "//conditions:default": ":ios", # Arbitrarily chosen from above.
+ }),
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/config.fbs b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/config.fbs
new file mode 100644
index 00000000..eba0bd8a
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/config.fbs
@@ -0,0 +1,25 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+namespace tflite.ops.custom.sentencepiece;
+
+table Trie {
+ nodes: [uint32];
+}
+
+
+enum EncoderVersion: byte {
+ SENTENCE_PIECE = 0,
+}
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config.fbs b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config.fbs
new file mode 100644
index 00000000..4a230ed9
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config.fbs
@@ -0,0 +1,43 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+include "config.fbs";
+
+namespace tflite.ops.custom.sentencepiece;
+
+
+table DecoderConfig {
+ version: EncoderVersion = SENTENCE_PIECE;
+
+  // The offset for encoding, usually used when low code values are reserved
+  // for special purposes.
+ encoding_offset: int32;
+
+ // A vector of strings that represent sentencepieces.
+ decode_pieces: [string];
+
+  // TODO(mgubin): Currently not populated; we haven't seen any Sentencepiece
+  // model with a denormalizer.
+ denormalized_prefixes: Trie;
+ denormalized_replacements: [byte];
+
+  // During encoding, a dummy prefix (a whitespace) can be added to the input
+  // string; if this flag is true, this prefix will be removed.
+ remove_dummy_prefix: bool;
+
+}
+
+
+root_type DecoderConfig;
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h
new file mode 100644
index 00000000..547e0ea8
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h
@@ -0,0 +1,120 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
+
+#include <functional>
+#include <vector>
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/config_generated.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+// A trie node specifies a node in the tree, either an intermediate node or
+// a leaf node.
+// A leaf node contains the integer id of the matched string. This id is
+// encoded in the lower 31 bits, thus the number of distinct ids is 2^31.
+// An intermediate node has an associated label and an offset to its children.
+// The label is encoded in the least significant byte and must match the input
+// character during matching.
+
+// A memory mappable trie, compatible with Darts::DoubleArray.
+class DoubleArrayTrie {
+ public:
+ struct Match {
+ Match() {}
+ Match(int id, int match_length) : id(id), match_length(match_length) {}
+ int id = -1;
+ int match_length = -1;
+ bool empty() const { return match_length == -1; }
+ bool operator==(const Match& m) const {
+ return m.id == id && m.match_length == match_length;
+ }
+ };
+
+  // nodes specifies the array of the nodes of the trie.
+ explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes)
+ : nodes_(nodes) {}
+
+ // Finds matches that are prefixes of a string.
+ template <typename callback>
+ void IteratePrefixMatches(const utils::string_view& input,
+ callback update_fn) const;
+
+ // Finds the longest prefix match of a string.
+ Match LongestPrefixMatch(const utils::string_view& input) const {
+ Match match;
+ IteratePrefixMatches(input, [&match](const Match& m) { match = m; });
+ return match;
+ }
+
+ private:
+  // Returns whether a node has a leaf as a child.
+ bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; }
+
+ // Returns a value associated with a node. Available when a node is a leaf.
+ int value(uint32_t i) const {
+ return static_cast<int>(((*nodes_)[i]) & 0x7fffffff);
+ }
+
+ // Returns a label associated with a node.
+ // A leaf node will have the MSB set and thus return an invalid label.
+ int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; }
+
+ // Returns offset to children.
+ int32_t offset(uint32_t i) const {
+ const uint32_t node = (*nodes_)[i];
+ return (node >> 10) << ((node & 0x200) >> 6);
+ }
+
+ const flatbuffers::Vector<uint32_t>* nodes_;
+};
+
+template <typename callback>
+void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input,
+ callback update_fn) const {
+ if (nodes_->size() == 0) {
+ return;
+ }
+ uint32_t pos = offset(0);
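+  // Darts-style double-array traversal: a child's position is the parent's
+  // offset XOR-ed with the next input byte.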
+ for (int i = 0; i < input.length(); ++i) {
+ pos ^= static_cast<unsigned char>(input.at(i));
+ if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) {
+ // No match, exit.
+ return;
+ }
+ const bool node_has_leaf = has_leaf(pos);
+ pos ^= offset(pos);
+ if (pos < 0 || pos >= nodes_->size()) {
+ // We can get here only if the structure is corrupted.
+ return;
+ }
+ if (node_has_leaf) {
+ update_fn(Match(value(pos), i + 1));
+ }
+ }
+}
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc
new file mode 100644
index 00000000..72b7262b
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc
@@ -0,0 +1,81 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h"
+
+#include <algorithm>
+#include <memory>
+
+#include "include/darts.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data) {
+ std::vector<int> ids;
+ ids.reserve(data.size());
+ for (int i = 0; i < data.size(); ++i) {
+ ids.push_back(i);
+ }
+ return BuildTrie(data, ids);
+}
+
+std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
+ const std::vector<int>& ids) {
+  // We make strong assumptions about the binary structure of the trie.
+ struct OneElement {
+ OneElement(const std::string* key_, int index_)
+ : key(key_), index(index_) {}
+ const std::string* key;
+ int index;
+ bool operator<(const OneElement& el) const { return *key < *el.key; }
+ };
+ std::vector<OneElement> elements;
+ elements.reserve(data.size());
+ auto data_iterator = std::begin(data);
+ auto ids_iterator = std::begin(ids);
+ for (; data_iterator != std::end(data) && ids_iterator != std::end(ids);
+ ++data_iterator, ++ids_iterator) {
+ elements.emplace_back(&(*data_iterator), *ids_iterator);
+ }
+ // Sort by keys.
+ std::sort(elements.begin(), elements.end());
+
+ // Create vectors to build the trie.
+ std::vector<const char*> strings;
+ std::vector<int32_t> indexes;
+ strings.reserve(data.size());
+ indexes.reserve(data.size());
+ for (const auto& el : elements) {
+ strings.push_back(el.key->c_str());
+ indexes.push_back(el.index);
+ }
+ auto trie = std::make_unique<Darts::DoubleArray>();
+ trie->build(data.size(), const_cast<char**>(&strings[0]), nullptr,
+ &indexes[0]);
+  // We make strong assumptions about the internal Darts trie structure:
+  // - it is a vector of 32-bit integers;
+  // - the "array" is the only structure that contains all information about
+  //   the trie.
+ const uint32_t* trie_data = static_cast<const uint32_t*>(trie->array());
+ return std::vector<uint32_t>(trie_data, trie_data + trie->size());
+}
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h
new file mode 100644
index 00000000..bc618abb
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h
@@ -0,0 +1,41 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/config_generated.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data,
+ const std::vector<int>& ids);
+
+// A variant where ids are indexes in data.
+std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data);
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_test.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_test.cc
new file mode 100644
index 00000000..8a53d094
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_test.cc
@@ -0,0 +1,78 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+TEST(DoubleArrayTrieTest, Match) {
+ flatbuffers::FlatBufferBuilder builder(1024);
+ const std::vector<std::string> test_strings = {"A", "AAX", "AA", "B"};
+ const auto trie_vector = builder.CreateVector(BuildTrie(test_strings));
+ TrieBuilder trie_builder(builder);
+ trie_builder.add_nodes(trie_vector);
+ const auto pieces = trie_builder.Finish();
+ EncoderConfigBuilder ecb(builder);
+ ecb.add_pieces(pieces);
+ FinishEncoderConfigBuffer(builder, ecb.Finish());
+ const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
+ DoubleArrayTrie dat(config->pieces()->nodes());
+ EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")),
+ DoubleArrayTrie::Match(2, 2));
+
+ std::vector<DoubleArrayTrie::Match> matches;
+ dat.IteratePrefixMatches(
+ utils::string_view("AAXL"),
+ [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
+ EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1),
+ DoubleArrayTrie::Match(2, 2),
+ DoubleArrayTrie::Match(1, 3)));
+}
+
+TEST(DoubleArrayTrieTest, ComplexMatch) {
+ flatbuffers::FlatBufferBuilder builder(1024);
+ const std::vector<std::string> test_strings = {"\xe2\x96\x81the", ",", "s",
+ "\xe2\x96\x81Hello"};
+ const std::vector<int> test_ids = {0, 5, 10, 15};
+ const auto trie_vector =
+ builder.CreateVector(BuildTrie(test_strings, test_ids));
+ TrieBuilder trie_builder(builder);
+ trie_builder.add_nodes(trie_vector);
+ const auto pieces = trie_builder.Finish();
+ EncoderConfigBuilder ecb(builder);
+ ecb.add_pieces(pieces);
+ FinishEncoderConfigBuffer(builder, ecb.Finish());
+ const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
+ DoubleArrayTrie dat(config->pieces()->nodes());
+
+ std::vector<DoubleArrayTrie::Match> matches;
+ dat.IteratePrefixMatches(
+ utils::string_view("\xe2\x96\x81Hello"),
+ [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); });
+ EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8)));
+}
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config.fbs b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config.fbs
new file mode 100644
index 00000000..7f1f2bad
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config.fbs
@@ -0,0 +1,52 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+include "config.fbs";
+
+namespace tflite.ops.custom.sentencepiece;
+
+table EncoderConfig {
+ // Version of the encoder.
+ version: EncoderVersion = SENTENCE_PIECE;
+ start_code: int32 = 0;
+ end_code: int32 = 0;
+
+ unknown_code: int32 = -1;
+ // Weight of "unknown code" when encoding. "Penalty" because it usually has a
+ // big negative weight,less than any other sentencepiece.
+ unknown_penalty: float = 0;
+
+  // The offset for encoding, usually used when low code values are reserved
+  // for special purposes.
+ encoding_offset: int32;
+
+ // String pieces for encoding.
+ pieces: Trie;
+ pieces_scores: [float];
+
+ // Normalization related parameters.
+ remove_extra_whitespaces: bool;
+
+ // Add a whitespace prefix before encoding.
+ add_dummy_prefix: bool;
+
+ // Escape whitespaces during encoding so the decoder can restore them exactly as
+ // in the input.
+ escape_whitespaces: bool;
+
+ // Normalization parameters.
+ normalized_prefixes: Trie;
+ normalized_replacements: [byte];
+}
+
+root_type EncoderConfig;
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
new file mode 100644
index 00000000..73e853ff
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc
@@ -0,0 +1,197 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h"
+
+#include "absl/status/status.h"
+#include "absl/strings/str_replace.h"
+#include "src/sentencepiece_model.pb.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
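+// Layout of the precompiled charsmap blob (as read below):
+//   [uint32 trie_size_bytes][trie nodes, trie_size_bytes bytes]
+//   [replacement strings, remaining bytes]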
+std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
+DecodePrecompiledCharsmap(
+ const ::sentencepiece::NormalizerSpec& normalizer_spec) {
+ // This function "undoes" encoding done by
+ // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
+ const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
+ const uint32_t trie_size =
+ *reinterpret_cast<const uint32_t*>(precompiled_map);
+ const uint32_t* trie_ptr =
+ reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
+ const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
+ precompiled_map + sizeof(uint32_t) + trie_size);
+ const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
+ sizeof(uint32_t) - trie_size;
+ return std::make_tuple(
+ std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
+ std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
+}
+
+tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
+ const std::string& model_config_str, int encoding_offset) {
+ ::sentencepiece::ModelProto model_config;
+ if (!model_config.ParseFromString(model_config_str)) {
+ return absl::InvalidArgumentError(
+ "Invalid configuration, can't parse SentencePiece model config " +
+ model_config.InitializationErrorString());
+ }
+ // Convert sentencepieces.
+ std::vector<std::string> pieces;
+ pieces.reserve(model_config.pieces_size());
+ std::vector<float> scores;
+ scores.reserve(model_config.pieces_size());
+ std::vector<int> ids;
+ ids.reserve(model_config.pieces_size());
+ float min_score = 0.0;
+ int index = 0;
+ for (const auto& piece : model_config.pieces()) {
+ switch (piece.type()) {
+ case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
+ case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
+ pieces.push_back(piece.piece());
+ ids.push_back(index);
+ if (piece.score() < min_score) {
+ min_score = piece.score();
+ }
+ break;
+ case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
+ case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
+ // Ignore unknown and control codes.
+ break;
+ default:
+ return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
+ piece.piece());
+ }
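+    // Record the score for every piece (including skipped ones) so that trie
+    // ids, which are indices into the full piece list, line up with
+    // pieces_scores.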
+ scores.push_back(piece.score());
+ ++index;
+ }
+ flatbuffers::FlatBufferBuilder builder(1024);
+ const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
+ const auto pieces_score_vector = builder.CreateVector(scores);
+ TrieBuilder pieces_trie_builder(builder);
+ pieces_trie_builder.add_nodes(pieces_trie_vector);
+ const auto pieces_trie_fbs = pieces_trie_builder.Finish();
+
+ // Converting normalization.
+ const auto [normalization_trie, normalization_strings] =
+ DecodePrecompiledCharsmap(model_config.normalizer_spec());
+ const auto normalization_trie_vector =
+ builder.CreateVector(normalization_trie);
+ TrieBuilder normalization_trie_builder(builder);
+ normalization_trie_builder.add_nodes(normalization_trie_vector);
+ const auto normalization_trie_fbs = normalization_trie_builder.Finish();
+ const auto normalization_strings_fbs =
+ builder.CreateVector(normalization_strings);
+
+ EncoderConfigBuilder ecb(builder);
+ ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
+ ecb.add_start_code(model_config.trainer_spec().bos_id());
+ ecb.add_end_code(model_config.trainer_spec().eos_id());
+ ecb.add_unknown_code(model_config.trainer_spec().unk_id());
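+  // Mirrors the SentencePiece unigram model: unknown pieces score kUnkPenalty
+  // below the lowest in-vocabulary piece score.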
+ ecb.add_unknown_penalty(min_score - kUnkPenalty);
+ ecb.add_encoding_offset(encoding_offset);
+ ecb.add_pieces(pieces_trie_fbs);
+ ecb.add_pieces_scores(pieces_score_vector);
+ ecb.add_remove_extra_whitespaces(
+ model_config.normalizer_spec().remove_extra_whitespaces());
+ ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
+ ecb.add_escape_whitespaces(
+ model_config.normalizer_spec().escape_whitespaces());
+ ecb.add_normalized_prefixes(normalization_trie_fbs);
+ ecb.add_normalized_replacements(normalization_strings_fbs);
+ FinishEncoderConfigBuffer(builder, ecb.Finish());
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+tflite::support::StatusOr<std::string>
+ConvertSentencepieceModelToFlatBufferForDecoder(
+ const std::string& model_config_str, int encoding_offset) {
+ ::sentencepiece::ModelProto model_config;
+ if (!model_config.ParseFromString(model_config_str)) {
+ return absl::InvalidArgumentError(
+ "Invalid configuration, can't parse SentencePiece model config " +
+ model_config.InitializationErrorString());
+ }
+ flatbuffers::FlatBufferBuilder builder(1024);
+ // Collect sentencepieces.
+ std::vector<std::string> pieces;
+ for (const auto& piece : model_config.pieces()) {
+    // In the original library, all piece processing is done during decoding.
+    // Because it is independent of context and parameters, we can do it in
+    // advance here.
+ switch (piece.type()) {
+ case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
+ case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
+ pieces.push_back(
+ absl::StrReplaceAll(piece.piece(), {{kSpaceSymbol, " "}}));
+ break;
+ case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
+ pieces.push_back(
+ kDefaultUnknownSymbol); // Always decode with the default unknown.
+ break;
+ default:
+ pieces.push_back("");
+ }
+ }
+ const auto pieces_fbs = builder.CreateVectorOfStrings(pieces);
+ DecoderConfigBuilder decb(builder);
+
+ decb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
+ decb.add_encoding_offset(encoding_offset);
+ decb.add_decode_pieces(pieces_fbs);
+ decb.add_remove_dummy_prefix(
+ model_config.normalizer_spec().add_dummy_prefix());
+
+ FinishDecoderConfigBuffer(builder, decb.Finish());
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+int GetVocabularySize(const std::string& model_string) {
+ const EncoderConfig* config = GetEncoderConfig(model_string.data());
+ return config->pieces_scores()->size() + config->encoding_offset();
+}
+
+std::string ConvertSentencepieceModel(const std::string& model_string) {
+ const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
+  // TODO(mgubin): Propagate the error to the Python code and throw a correct
+  // exception.
+ assert(result.status().ok());
+ return result.value();
+}
+
+std::string ConvertSentencepieceModelForDecoder(
+ const std::string& model_string) {
+ const auto result =
+ ConvertSentencepieceModelToFlatBufferForDecoder(model_string);
+  // TODO(mgubin): Propagate the error to the Python code and throw a correct
+  // exception.
+ assert(result.status().ok());
+ return result.value();
+}
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
new file mode 100644
index 00000000..5687b628
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_MODEL_CONVERTER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_MODEL_CONVERTER_H_
+#include <string>
+
+#include "tensorflow_lite_support/cc/port/statusor.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+// Converts Sentencepiece configuration to flatbuffer format.
+// encoding_offset is used by some encoders that combine different encodings.
+tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
+ const std::string& model_config_str, int encoding_offset = 0);
+
+// Converts Sentencepiece configuration to flatbuffer format for the decoder.
+// encoding_offset is used by some encoders that combine different encodings.
+tflite::support::StatusOr<std::string>
+ConvertSentencepieceModelToFlatBufferForDecoder(
+ const std::string& model_config_str, int encoding_offset = 0);
+
+// The functions that are provided for the Python wrapper.
+std::string ConvertSentencepieceModel(const std::string& model_string);
+std::string ConvertSentencepieceModelForDecoder(
+ const std::string& model_string);
+
+// Returns the vocabulary size from a Sentencepiece configuration in
+// flatbuffer format.
+int GetVocabularySize(const std::string& model_string);
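+
+// Example (sketch, assuming `model_proto_str` holds a serialized SentencePiece
+// ModelProto):
+//   auto encoder_fb = ConvertSentencepieceModelToFlatBuffer(model_proto_str);
+//   if (encoder_fb.ok()) {
+//     // encoder_fb.value() holds the flatbuffer; pass .data() to EncodeString.
+//   }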
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_MODEL_CONVERTER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl
new file mode 100644
index 00000000..87695a46
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl
@@ -0,0 +1,86 @@
+"""Build definitions supporting platform-independent native build."""
+
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_copts", "tf_opts_nortti_if_android")
+load("@bazel_skylib//lib:selects.bzl", "selects")
+
+def micore_if(android, ios = [], default = []):
+ """Helper to create a select.
+
+ Args:
+ android: what to return if compiling for Android.
+ ios: what to return if compiling for iOS.
+ default: what to return otherwise.
+ Returns:
+    the `android` list for Android compilation, the `ios` list for iOS
+    compilation, and the `default` list otherwise.
+ """
+ return select({
+ ":android": android,
+ ":apple": ios,
+ "//conditions:default": default,
+ })
+
+def micore_tf_copts():
+ """C options for Tensorflow builds.
+
+ Returns:
+ a list of copts which must be used by each cc_library which
+ refers to Tensorflow. Enables the library to compile both for
+ Android and for Linux.
+ """
+ return tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [
+ "-Wno-narrowing",
+ "-Wno-sign-compare",
+ "-Wno-overloaded-virtual",
+ ] + micore_if(
+ android = [
+ # Set a define so Tensorflow's register_types.h
+            # adapts to support a rich set of types, to be pruned by
+ # selective registration.
+ "-DSUPPORT_SELECTIVE_REGISTRATION",
+ # Selective registration uses constexprs with recursive
+ # string comparisons; that can lead to compiler errors, so
+ # we increase the constexpr recursion depth.
+ "-fconstexpr-depth=1024",
+ ],
+ ) + selects.with_or({
+ # If building for armeabi-v7a, and if compilation_mode is 'fastbuild'
+    # or 'dbg', then forcefully add -Oz to the list of compiler options.
+ # Without it, some TF dependencies can't build (b/112286436). If
+ # compilation_mode is 'opt' then rely on the toolchain default.
+ (
+ ":armeabi_v7a_and_fastbuild",
+ ":armeabi_v7a_and_dbg",
+ ): ["-Oz"],
+ "//conditions:default": [],
+ })
+
+def micore_tf_deps():
+ """Dependencies for Tensorflow builds.
+
+ Returns:
+ list of dependencies which must be used by each cc_library
+ which refers to Tensorflow. Enables the library to compile both for
+ Android and for Linux. Use this macro instead of directly
+ declaring dependencies on Tensorflow.
+ """
+ return micore_if(
+ android = [
+ # Link to library which does not contain any ops.
+ "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite",
+ "@gemmlowp//:eight_bit_int_gemm",
+ "@fft2d//:fft2d",
+ ],
+ ios = [
+ "@org_tensorflow//tensorflow/core:portable_tensorflow_lib",
+ "@gemmlowp//:eight_bit_int_gemm",
+ "@fft2d//:fft2d",
+ ],
+ default = [
+ # Standard references for Tensorflow when building for Linux. We use
+ # an indirection via the alias targets below, to facilitate whitelisting
+ # these deps in the mobile license presubmit checks.
+ "@local_config_tf//:libtensorflow_framework",
+ "@local_config_tf//:tf_header_lib",
+ ],
+ )
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.cc
new file mode 100644
index 00000000..86e186da
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.cc
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h"
+
+#include <string>
+#include <tuple>
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+DecoderResult DecodeString(const std::vector<int>& encoded,
+ const void* config_buffer) {
+ DecoderResult result;
+
+ // Get the config from the buffer.
+ const DecoderConfig* config = GetDecoderConfig(config_buffer);
+ if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
+ result.type = DecoderResultType::WRONG_CONFIG;
+ return result;
+ }
+ bool remove_dummy_prefix = config->remove_dummy_prefix();
+ const auto config_pieces = config->decode_pieces();
+ for (const auto code : encoded) {
+ const int real_code = code - config->encoding_offset();
+    if (real_code < 0 || real_code >= config_pieces->size()) {
+ result.type = DecoderResultType::INVALID_INPUT;
+ return result;
+ }
+ const auto& piece_text = config_pieces->GetAsString(real_code);
+ const char* piece_str = piece_text->c_str();
+ if (remove_dummy_prefix && *piece_str == ' ') {
+ ++piece_str;
+ }
+ result.decoded.append(piece_str);
+ remove_dummy_prefix = false;
+ }
+ // TODO(mgubin): Denormalize the string, haven't seen any Sentencepiece model
+ // with a denormalizer.
+ return result;
+}
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h
new file mode 100644
index 00000000..a4424687
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h
@@ -0,0 +1,50 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_DECODER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_DECODER_H_
+
+// Sentencepiece decoder optimized for use with a memmapped model.
+
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+enum class DecoderResultType {
+ SUCCESS = 0,
+ WRONG_CONFIG = 1,
+ INVALID_INPUT = 2
+};
+
+struct DecoderResult {
+ DecoderResultType type = DecoderResultType::SUCCESS;
+ std::string decoded;
+};
+
+// Decodes one string from a vector of ids. Takes the configuration as a
+// type-erased buffer.
+DecoderResult DecodeString(const std::vector<int>& encoded,
+ const void* config_buffer);
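+//
+// Example (sketch), with `config` holding a flatbuffer produced by
+// ConvertSentencepieceModelForDecoder and arbitrary example ids:
+//   DecoderResult res = DecodeString({5, 10, 15}, config.data());
+//   if (res.type == DecoderResultType::SUCCESS) { /* use res.decoded */ }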
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_DECODER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
new file mode 100644
index 00000000..04d1c85a
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h"
+
+#include <fstream>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "src/sentencepiece.pb.h"
+#include "src/sentencepiece_processor.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+namespace internal {
+
+tensorflow::Status TFReadFileToString(const std::string& filepath,
+ std::string* data) {
+ return tensorflow::ReadFileToString(tensorflow::Env::Default(),
+                                      /*fname=*/filepath, data);
+}
+
+absl::Status StdReadFileToString(const std::string& filepath,
+ std::string* data) {
+ std::ifstream infile(filepath);
+ if (!infile.is_open()) {
+ return absl::NotFoundError(
+ absl::StrFormat("Error when opening %s", filepath));
+ }
+ std::string contents((std::istreambuf_iterator<char>(infile)),
+ (std::istreambuf_iterator<char>()));
+ data->append(contents);
+ infile.close();
+ return absl::OkStatus();
+}
+
+} // namespace internal
+
+namespace {
+static char kConfigFilePath[] =
+ "tensorflow_lite_support/custom_ops/kernel/"
+ "sentencepiece/testdata/sentencepiece.model";
+
+TEST(OptimizedDecoder, ConfigConverter) {
+ std::string config;
+ auto status = internal::StdReadFileToString(kConfigFilePath, &config);
+
+ ASSERT_TRUE(status.ok());
+
+ ::sentencepiece::SentencePieceProcessor processor;
+ ASSERT_OK(processor.LoadFromSerializedProto(config));
+ const auto converted_model = ConvertSentencepieceModelForDecoder(config);
+ const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
+ ::sentencepiece::SentencePieceText reference_encoded;
+ CHECK_OK(processor.Encode(test_string, &reference_encoded));
+
+ std::vector<int> encoded_vector;
+ encoded_vector.reserve(reference_encoded.pieces_size());
+ for (const auto& piece : reference_encoded.pieces()) {
+ encoded_vector.push_back(piece.id());
+ }
+ std::string ref_decoded;
+ ASSERT_OK(processor.Decode(encoded_vector, &ref_decoded));
+ const auto decoded = DecodeString(encoded_vector, converted_model.data());
+ ASSERT_EQ(decoded.type, DecoderResultType::SUCCESS);
+ ASSERT_EQ(ref_decoded, decoded.decoded);
+}
+} // namespace
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
new file mode 100644
index 00000000..5a59ee48
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc
@@ -0,0 +1,239 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
+
+#include <algorithm>
+#include <tuple>
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+namespace {
+
+const char kSpaceSymbol[] = "\xe2\x96\x81";
+
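+// Rewrites `input` with the processing callback `pc` while maintaining, for
+// every output byte, the offset of the input byte it came from. The callback
+// returns (consumed, replacement): consumed == 0 copies one byte through
+// unchanged; consumed > 0 substitutes `replacement` for the consumed bytes.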
+template <typename processing_callback>
+std::tuple<std::string, std::vector<int>> process_string(
+ const std::string& input, const std::vector<int>& offsets,
+ const processing_callback& pc) {
+ std::string result_string;
+ result_string.reserve(input.size());
+ std::vector<int> result_offsets;
+ result_offsets.reserve(offsets.size());
+ for (int i = 0, j = 0; i < input.size();) {
+ auto [consumed, new_string] = pc(input.data() + i, input.size() - i);
+ if (consumed == 0) {
+      // Copy the current byte through unchanged and move forward.
+ result_string.push_back(input[i]);
+ result_offsets.push_back(offsets[j]);
+ i++;
+ j++;
+ continue;
+ }
+ result_string.append(new_string.data(), new_string.length());
+    for (int k = 0; k < new_string.length(); ++k) {
+ result_offsets.push_back(offsets[j]);
+ }
+ j += consumed;
+ i += consumed;
+ }
+ return std::make_tuple(result_string, result_offsets);
+}
+
+inline bool is_whitespace(char c) {
+ return c == ' ' || c == '\t' || c == '\r' || c == '\n';
+}
+
+std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
+ int len) {
+ if (len == 0 || !is_whitespace(*data)) {
+ return std::make_tuple(0, utils::string_view(nullptr, 0));
+ }
+ int num_consumed = 1;
+ for (; num_consumed < len && is_whitespace(data[num_consumed]);
+ ++num_consumed) {
+ }
+ return num_consumed > 1
+ ? std::make_tuple(num_consumed, utils::string_view(" ", 1))
+ : std::make_tuple(0, utils::string_view(nullptr, 0));
+}
+
+std::tuple<int, utils::string_view> find_replacement(
+ const char* data, int len, const DoubleArrayTrie& dat,
+ const flatbuffers::Vector<int8_t>& replacements) {
+ const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
+ if (!max_match.empty()) {
+    // Because a flatbuffer byte is a signed char, which is not the same type
+    // as char, a reinterpret_cast is needed here.
+ const char* replaced_string_ptr =
+ reinterpret_cast<const char*>(replacements.data() + max_match.id);
+ return std::make_tuple(max_match.match_length,
+ utils::string_view(replaced_string_ptr));
+ }
+ return std::make_tuple(0, utils::string_view(nullptr, 0));
+}
+} // namespace
+
+std::tuple<std::string, std::vector<int>> NormalizeString(
+ const std::string& in_string, const EncoderConfig& config) {
+ std::vector<int> output_offsets;
+ std::string result = in_string;
+ output_offsets.reserve(in_string.length());
+ for (int i = 0; i < in_string.length(); ++i) {
+ output_offsets.push_back(i);
+ }
+ if (in_string.empty()) {
+ return std::make_tuple(result, output_offsets);
+ }
+ if (config.add_dummy_prefix()) {
+ result.insert(result.begin(), ' ');
+ output_offsets.insert(output_offsets.begin(), 0);
+ }
+  // Greedily replace normalized_prefixes with normalized_replacements.
+ if (config.normalized_prefixes() != nullptr &&
+ config.normalized_replacements() != nullptr) {
+ const DoubleArrayTrie normalized_prefixes_matcher(
+ config.normalized_prefixes()->nodes());
+ const auto norm_replace = [&config, &normalized_prefixes_matcher](
+ const char* data, int len) {
+ return find_replacement(data, len, normalized_prefixes_matcher,
+ *config.normalized_replacements());
+ };
+ std::tie(result, output_offsets) =
+ process_string(result, output_offsets, norm_replace);
+ }
+ if (config.remove_extra_whitespaces()) {
+ std::tie(result, output_offsets) =
+ process_string(result, output_offsets, remove_extra_whitespaces);
+ if (!result.empty() && is_whitespace(result.back())) {
+ result.pop_back();
+ output_offsets.pop_back();
+ }
+ }
+ if (config.escape_whitespaces()) {
+ const auto replace_whitespaces = [](const char* data, int len) {
+ if (len > 0 && is_whitespace(*data)) {
+ return std::make_tuple(1, utils::string_view(kSpaceSymbol));
+ }
+ return std::make_tuple(0, utils::string_view(nullptr, 0));
+ };
+ std::tie(result, output_offsets) =
+ process_string(result, output_offsets, replace_whitespaces);
+ }
+
+ return std::make_tuple(result, output_offsets);
+}
+
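+// Encodes via Viterbi-style dynamic programming: lattice[i] holds the best
+// (highest-score) way to reach byte position i. Each position is extended by
+// all trie piece matches starting there, plus a single-byte "unknown" step
+// with a penalized score when unknown_code is configured; the best path is
+// then read back from lattice[length].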
+EncoderResult EncodeNormalizedString(const std::string& str,
+ const std::vector<int>& offsets,
+ const EncoderConfig& config, bool add_bos,
+ bool add_eos, bool reverse) {
+ const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
+ const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
+ const int unknown_code = config.unknown_code();
+ const float unknown_penalty = config.unknown_penalty();
+ struct LatticeElement {
+ float score = 0;
+ int code = -1;
+ int prev_position = -1;
+ LatticeElement(float score_, int code_, int prev_position_)
+ : score(score_), code(code_), prev_position(prev_position_) {}
+ LatticeElement() {}
+ };
+ const int length = str.length();
+ std::vector<LatticeElement> lattice(length + 1);
+ for (int i = 0; i < length; ++i) {
+ if (i > 0 && lattice[i].prev_position < 0) {
+ // This state is unreachable.
+ continue;
+ }
+ if (unknown_code >= 0) {
+ // Put unknown code.
+ const float penalized_score = lattice[i].score + unknown_penalty;
+ const int pos = i + 1;
+ LatticeElement& current_element = lattice[pos];
+ if (current_element.prev_position < 0 ||
+ current_element.score < penalized_score) {
+ current_element = LatticeElement(
+ penalized_score, unknown_code,
+ // If the current state is already reached by unknown code, merge
+ // states.
+ lattice[i].code == unknown_code ? lattice[i].prev_position : i);
+ }
+ }
+ auto lattice_update = [&lattice, i,
+ piece_scores](const DoubleArrayTrie::Match& m) {
+ LatticeElement& target_element = lattice[i + m.match_length];
+ const float score = lattice[i].score + (*piece_scores)[m.id];
+ if (target_element.prev_position < 0 || target_element.score < score) {
+ target_element = LatticeElement(score, m.id, i);
+ }
+ };
+ piece_matcher.IteratePrefixMatches(
+ utils::string_view(str.data() + i, length - i), lattice_update);
+ }
+
+ EncoderResult result;
+ if (add_eos) {
+ result.codes.push_back(config.end_code());
+ result.offsets.push_back(length);
+ }
+ if (lattice[length].prev_position >= 0) {
+ for (int pos = length; pos > 0;) {
+ auto code = lattice[pos].code;
+ if (code != config.unknown_code()) {
+ code += config.encoding_offset();
+ }
+ result.codes.push_back(code);
+ pos = lattice[pos].prev_position;
+ result.offsets.push_back(offsets[pos]);
+ }
+ }
+ if (add_bos) {
+ result.codes.push_back(config.start_code());
+ result.offsets.push_back(0);
+ }
+ if (!reverse) {
+ std::reverse(result.codes.begin(), result.codes.end());
+ std::reverse(result.offsets.begin(), result.offsets.end());
+ }
+ return result;
+}
+
+EncoderResult EncodeString(const std::string& string, const void* config_buffer,
+ bool add_bos, bool add_eos, bool reverse) {
+ // Get the config from the buffer.
+ const EncoderConfig* config = GetEncoderConfig(config_buffer);
+ if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
+ EncoderResult result;
+ result.type = EncoderResultType::WRONG_CONFIG;
+ return result;
+ }
+ std::string normalized_string;
+ std::vector<int> offsets;
+ std::tie(normalized_string, offsets) = NormalizeString(string, *config);
+ return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
+ add_eos, reverse);
+}
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
new file mode 100644
index 00000000..44d6e88f
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
+
+// Sentencepiece encoder optimized for use with a memmapped model.
+
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
+
+struct EncoderResult {
+ EncoderResultType type = EncoderResultType::SUCCESS;
+ std::vector<int> codes;
+ std::vector<int> offsets;
+};
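+
+// Normalizes `in_string` per `config` (dummy prefix, prefix replacements,
+// extra-whitespace removal, whitespace escaping). Returns the normalized
+// string and, for each output byte, the offset of its source byte.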
+std::tuple<std::string, std::vector<int>> NormalizeString(
+ const std::string& in_string, const EncoderConfig& config);
+
+// Encodes one string and returns ids and offsets. Takes the configuration as a
+// type-erased buffer.
+EncoderResult EncodeString(const std::string& string, const void* config_buffer,
+ bool add_bos, bool add_eos, bool reverse);
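+//
+// Example (sketch), with `model` holding a flatbuffer produced by
+// ConvertSentencepieceModel:
+//   EncoderResult res = EncodeString("hello world", model.data(),
+//                                    /*add_bos=*/false, /*add_eos=*/false,
+//                                    /*reverse=*/false);
+//   // On SUCCESS, res.codes and res.offsets are parallel vectors.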
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
new file mode 100644
index 00000000..ad3cd27f
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc
@@ -0,0 +1,167 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
+
+#include <fstream>
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "src/sentencepiece.pb.h"
+#include "src/sentencepiece_processor.h"
+#include "tensorflow/core/platform/env.h"
+
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+namespace internal {
+
+tensorflow::Status TFReadFileToString(
+ const std::string& filepath, std::string* data) {
+ return tensorflow::ReadFileToString(
+      tensorflow::Env::Default(), /*fname=*/filepath, data);
+}
+
+absl::Status StdReadFileToString(
+ const std::string& filepath, std::string* data) {
+ std::ifstream infile(filepath);
+ if (!infile.is_open()) {
+ return absl::NotFoundError(
+ absl::StrFormat("Error when opening %s", filepath));
+ }
+ std::string contents((std::istreambuf_iterator<char>(infile)),
+ (std::istreambuf_iterator<char>()));
+ data->append(contents);
+ infile.close();
+ return absl::OkStatus();
+}
+} // namespace internal
+
+namespace {
+
+static char kConfigFilePath[] =
+ "tensorflow_lite_support/custom_ops/kernel/"
+ "sentencepiece/testdata/sentencepiece.model";
+
+TEST(OptimizedEncoder, NormalizeStringWhitespaces) {
+ flatbuffers::FlatBufferBuilder builder(1024);
+ EncoderConfigBuilder ecb(builder);
+ ecb.add_remove_extra_whitespaces(true);
+ ecb.add_add_dummy_prefix(true);
+ ecb.add_escape_whitespaces(true);
+ FinishEncoderConfigBuffer(builder, ecb.Finish());
+ const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
+ {
+ const auto [res_string, offsets] = NormalizeString("x y", *config);
+ EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
+ EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
+ }
+ {
+ const auto [res_string, offsets] = NormalizeString("\tx y\n", *config);
+ EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
+ EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
+ }
+}
+
+TEST(OptimizedEncoder, NormalizeStringReplacement) {
+ flatbuffers::FlatBufferBuilder builder(1024);
+ const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
+ const char norm_replacements[] = "A1\0A2\0A3\0A4";
+ const auto trie_vector =
+ builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
+ const auto norm_r = builder.CreateVector<int8_t>(
+ reinterpret_cast<const signed char*>(norm_replacements),
+ sizeof(norm_replacements));
+ TrieBuilder trie_builder(builder);
+ trie_builder.add_nodes(trie_vector);
+ const auto norm_p = trie_builder.Finish();
+ EncoderConfigBuilder ecb(builder);
+ ecb.add_remove_extra_whitespaces(false);
+ ecb.add_normalized_prefixes(norm_p);
+ ecb.add_normalized_replacements(norm_r);
+ FinishEncoderConfigBuffer(builder, ecb.Finish());
+ const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
+ {
+ const auto [res_string, offsets] =
+ NormalizeString("ABAABAAABAAAA", *config);
+ EXPECT_EQ(res_string, "A1BA2BA3BA4");
+ EXPECT_THAT(offsets,
+ ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
+ }
+}
+
+TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
+ flatbuffers::FlatBufferBuilder builder(1024);
+ const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
+ "X"};
+ const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
+ const auto trie_vector =
+ builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
+ const auto norm_r = builder.CreateVector<int8_t>(
+ reinterpret_cast<const signed char*>(norm_replacements),
+ sizeof(norm_replacements));
+ TrieBuilder trie_builder(builder);
+ trie_builder.add_nodes(trie_vector);
+ const auto norm_p = trie_builder.Finish();
+ EncoderConfigBuilder ecb(builder);
+ ecb.add_remove_extra_whitespaces(true);
+ ecb.add_normalized_prefixes(norm_p);
+ ecb.add_normalized_replacements(norm_r);
+ FinishEncoderConfigBuffer(builder, ecb.Finish());
+ const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
+ {
+ const auto [res_string, offsets] =
+ NormalizeString("XXABAABAAABAAAA", *config);
+ EXPECT_EQ(res_string, " A1BA2BA3BA4");
+ EXPECT_THAT(offsets,
+ ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
+ }
+}
+
+TEST(OptimizedEncoder, ConfigConverter) {
+ std::string config;
+ auto status = internal::StdReadFileToString(kConfigFilePath, &config);
+ ASSERT_TRUE(status.ok());
+
+ ::sentencepiece::SentencePieceProcessor processor;
+ ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
+ const auto converted_model = ConvertSentencepieceModel(config);
+ const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
+ const auto encoded =
+ EncodeString(test_string, converted_model.data(), false, false, false);
+ ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());
+
+ ::sentencepiece::SentencePieceText reference_encoded;
+ ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
+ EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
+ for (int i = 0; i < encoded.codes.size(); ++i) {
+ EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
+ EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
+ }
+}
+
+} // namespace
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.cc
new file mode 100644
index 00000000..5345409f
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.cc
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
+TfLiteRegistration* Register_SENTENCEPIECE_DETOKENIZER();
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+extern "C" void TFLite_SentencepieceTokenizerRegisterer(
+ tflite::MutableOpResolver* resolver) {
+ resolver->AddCustom("TFSentencepieceTokenizeOp",
+ tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
+ resolver->AddCustom(
+ "TFSentencepieceDetokenizeOp",
+ tflite::ops::custom::Register_SENTENCEPIECE_DETOKENIZER());
+}
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
new file mode 100644
index 00000000..deb4e4ee
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h
@@ -0,0 +1,25 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+// C function that is called from the Python wrapper.
+
+extern "C" void TFLite_SentencepieceTokenizerRegisterer(
+ tflite::MutableOpResolver *resolver);
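+//
+// Example (sketch): registering the tokenizer ops on a resolver before
+// building an interpreter:
+//   tflite::MutableOpResolver resolver;
+//   TFLite_SentencepieceTokenizerRegisterer(&resolver);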
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h
new file mode 100644
index 00000000..55644ba6
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h
@@ -0,0 +1,43 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+// The constant is copied from
+// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
+constexpr float kUnkPenalty = 10.0;
+
+// These constants are copied from
+// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
+//
+// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
+constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
+
+// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
+// since this character can be useful both for user and
+// developer. We can easily figure out that <unk> is emitted.
+constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";
+
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h
new file mode 100644
index 00000000..1f4e0f4d
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_
+
+// Constants are shared between TF and TFLite SentencepieceTokenizer kernels.
+namespace tensorflow {
+namespace ops {
+constexpr int kSPModelIndex = 0;
+constexpr int kInputIndex = 1;
+constexpr int kInputSplits = 2;
+constexpr int kAddBOSInput = 4;
+constexpr int kAddEOSInput = 5;
+constexpr int kReverseInput = 6;
+} // namespace ops
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_op.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_op.cc
new file mode 100644
index 00000000..bd4b5a17
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_op.cc
@@ -0,0 +1,94 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/protobuf/error_codes.pb.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h"
+
+namespace tensorflow {
+namespace ops {
+REGISTER_OP("TFSentencepieceDetokenizeOp")
+ .Input("sp_model: uint8")
+ .Input("input_values: int32")
+ .Input("input_splits: Tsplits")
+ .Attr("Tsplits: {int32, int64} = DT_INT64")
+ .Output("output: string")
+ .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
+
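+      // N sentences are described by N + 1 split offsets, so the output holds
+      // NumElements(input_splits) - 1 strings.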
+ shape_inference::DimensionHandle dim;
+ TF_RETURN_IF_ERROR(c->Subtract(c->NumElements(c->input(2)), 1, &dim));
+ c->set_output(0, c->Vector(dim));
+ return Status::OK();
+ });
+
+template <typename Tsplits>
+class TFSentencepieceDetokenizerOp : public tensorflow::OpKernel {
+ public:
+ explicit TFSentencepieceDetokenizerOp(tensorflow::OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+ void Compute(tensorflow::OpKernelContext* ctx) override {
+ const auto& model_tensor = ctx->input(kSPModelIndex);
+ const auto& input_values_tensor = ctx->input(kInputIndex);
+ const auto input_values_flat =
+ input_values_tensor.flat<tensorflow::int32>();
+ const auto& input_splits_tensor = ctx->input(kInputSplits);
+ const auto input_splits_flat = input_splits_tensor.flat<Tsplits>();
+ const int num_of_sentences = input_splits_flat.size() - 1;
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, {num_of_sentences}, &output_tensor));
+ auto output_flat = output_tensor->flat<tensorflow::tstring>();
+ std::vector<int> codes_for_split;
+ int input_offset = 0;
+ for (int i = 0; i < num_of_sentences; i++) {
+ // Create a vector of int32 from input according to spans.
+ const int split_size = input_splits_flat(i + 1) - input_splits_flat(i);
+ codes_for_split.clear();
+ codes_for_split.reserve(split_size);
+ for (int j = 0; j < split_size; ++j) {
+ codes_for_split.push_back(input_values_flat(input_offset++));
+ }
+ const auto res = tflite::ops::custom::sentencepiece::DecodeString(
+ codes_for_split, model_tensor.data());
+ OP_REQUIRES(
+ ctx,
+ res.type ==
+ tflite::ops::custom::sentencepiece::DecoderResultType::SUCCESS,
+ tensorflow::Status(tensorflow::error::INTERNAL,
+ "Sentencepiece conversion failed"));
+ output_flat(i) = res.decoded;
+ }
+ }
+};
+} // namespace ops
+} // namespace tensorflow
+
+REGISTER_KERNEL_BUILDER(
+ Name("TFSentencepieceDetokenizeOp")
+ .Device(tensorflow::DEVICE_CPU)
+ .TypeConstraint<tensorflow::int32>("Tsplits"),
+ tensorflow::ops::TFSentencepieceDetokenizerOp<tensorflow::int32>);
+REGISTER_KERNEL_BUILDER(
+ Name("TFSentencepieceDetokenizeOp")
+ .Device(tensorflow::DEVICE_CPU)
+ .TypeConstraint<tensorflow::int64>("Tsplits"),
+ tensorflow::ops::TFSentencepieceDetokenizerOp<tensorflow::int64>);
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
new file mode 100644
index 00000000..54b34e4e
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc
@@ -0,0 +1,100 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+/**
+ * Sentencepiece tflite detokenizer implementation.
+ */
+#include <algorithm>
+#include <iterator>
+
+#include "flatbuffers/flexbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_util.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+namespace detokenizer {
+
+constexpr int kOutputValuesInd = 0;
+// The detokenizer keeps no per-op state, so initialization returns nullptr.
+void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
+ size_t /*length*/) {
+ return nullptr;
+}
+void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // TODO(mgubin): Add checks for input and output tensors.
+ TfLiteTensor& output_values =
+ context->tensors[node->outputs->data[kOutputValuesInd]];
+ SetTensorToDynamic(&output_values);
+ // TODO(mgubin): Check input types.
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor& model_tensor =
+ context->tensors[node->inputs->data[tensorflow::ops::kSPModelIndex]];
+ const auto model_buffer_data = model_tensor.data.data;
+ const TfLiteTensor& input_encoded =
+ context->tensors[node->inputs->data[tensorflow::ops::kInputIndex]];
+ const int32_t* input_encoded_data = input_encoded.data.i32;
+ const TfLiteTensor& input_splits =
+ context->tensors[node->inputs->data[tensorflow::ops::kInputSplits]];
+ const int num_of_sentences = NumElements(input_splits.dims) - 1;
+ const int32_t* input_splits_data = input_splits.data.i32;
+
+ DynamicBuffer buf;
+
+ std::vector<int> codes_for_split;
+ int input_offset = 0;
+ for (int i = 0; i < num_of_sentences; i++) {
+ // Create a vector of int32 from input according to spans.
+ const int split_size = input_splits_data[i + 1] - input_splits_data[i];
+ codes_for_split.clear();
+ std::copy(input_encoded_data + input_offset,
+ input_encoded_data + input_offset + split_size,
+ std::back_inserter(codes_for_split));
+ const auto res = DecodeString(codes_for_split, model_buffer_data);
+ TF_LITE_ENSURE_MSG(context, res.type == DecoderResultType::SUCCESS,
+ "Sentencepiece decoding failed");
+ buf.AddString(res.decoded.data(), res.decoded.length());
+ input_offset += split_size;
+ }
+ TfLiteTensor& output_values =
+ context->tensors[node->outputs->data[kOutputValuesInd]];
+ buf.WriteToTensor(&output_values, nullptr);
+ return kTfLiteOk;
+}
+} // namespace detokenizer
+} // namespace sentencepiece
+
+TfLiteRegistration* Register_SENTENCEPIECE_DETOKENIZER() {
+ static TfLiteRegistration r = {
+ sentencepiece::detokenizer::Initialize, sentencepiece::detokenizer::Free,
+ sentencepiece::detokenizer::Prepare, sentencepiece::detokenizer::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h
new file mode 100644
index 00000000..cb3ee07f
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_
+
+// Constants are shared between TF and TFLite SentencepieceTokenizer kernels.
+namespace tensorflow {
+namespace ops {
+
+constexpr int kSPModelIndex = 0;
+constexpr int kInputIndex = 1;
+constexpr int kAddBOSInput = 4;
+constexpr int kAddEOSInput = 5;
+constexpr int kReverseInput = 6;
+} // namespace ops
+} // namespace tensorflow
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
new file mode 100644
index 00000000..41fc5aa2
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_op.cc
@@ -0,0 +1,119 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iterator>
+#include <vector>
+
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/protobuf/error_codes.pb.h"
+
+namespace tensorflow {
+namespace ops {
+
+// copied from third_party/tensorflow_text/core/ops/sentencepiece_ops.cc
+REGISTER_OP("TFSentencepieceTokenizeOp")
+ .Input("sp_model: uint8")
+ .Input("input: string")
+ .Input("nbest_size: int32")
+ .Input("alpha: float")
+ .Input("add_bos: bool")
+ .Input("add_eos: bool")
+ .Input("reverse: bool")
+ .Attr("out_type: {int32, string} = DT_INT32")
+ .Attr("Tsplits: {int32, int64} = DT_INT32")
+ .Output("output_values: out_type")
+ .Output("output_splits: Tsplits")
+ .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
+ tensorflow::shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
+
+ c->set_output(
+ 0, c->Vector(
+ tensorflow::shape_inference::InferenceContext::kUnknownDim));
+
+ tensorflow::shape_inference::DimensionHandle num_splits;
+ TF_RETURN_IF_ERROR(c->Add(c->NumElements(c->input(1)), 1, &num_splits));
+ c->set_output(1, c->Vector(num_splits));
+ return tensorflow::Status::OK();
+ });
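+
+// Note: for a batch of N input strings, `output_values` is a flat vector of
+// ids and `output_splits` holds N + 1 offsets into it, i.e. the row_splits
+// of the resulting ragged tensor.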
+
+class TFSentencepieceOp : public tensorflow::OpKernel {
+ public:
+ explicit TFSentencepieceOp(tensorflow::OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+ void Compute(tensorflow::OpKernelContext* ctx) override {
+ const auto& model_tensor = ctx->input(kSPModelIndex);
+ const auto& input_values_tensor = ctx->input(kInputIndex);
+ const auto input_values_flat =
+ input_values_tensor.flat<tensorflow::tstring>();
+ const int num_of_input_values = input_values_flat.size();
+
+ const auto& add_bos_tensor = ctx->input(kAddBOSInput);
+ const bool add_bos = add_bos_tensor.scalar<bool>()();
+ const auto& add_eos_tensor = ctx->input(kAddEOSInput);
+ const bool add_eos = add_eos_tensor.scalar<bool>()();
+ const auto& reverse_tensor = ctx->input(kReverseInput);
+ const bool reverse = reverse_tensor.scalar<bool>()();
+
+ std::vector<int32> encoded;
+ std::vector<int32> splits;
+ for (int i = 0; i < num_of_input_values; ++i) {
+ const auto res = ::tflite::ops::custom::sentencepiece::EncodeString(
+ input_values_flat(i), model_tensor.data(), add_bos, add_eos, reverse);
+ OP_REQUIRES(
+ ctx,
+ res.type ==
+ ::tflite::ops::custom::sentencepiece::EncoderResultType::SUCCESS,
+ tensorflow::Status(tensorflow::error::INTERNAL,
+ "Sentencepiece conversion failed"));
+ std::copy(res.codes.begin(), res.codes.end(),
+ std::back_inserter(encoded));
+ splits.emplace_back(encoded.size());
+ }
+ tensorflow::Tensor* output_values_tensor = nullptr;
+ tensorflow::Tensor* output_splits_tensor = nullptr;
+
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, {encoded.size()}, &output_values_tensor));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {splits.size() + 1},
+ &output_splits_tensor));
+
+ auto values_tensor_flat = output_values_tensor->vec<int32>();
+ auto splits_tensor_flat = output_splits_tensor->vec<int32>();
+ for (int i = 0; i < encoded.size(); ++i) {
+ values_tensor_flat(i) = encoded[i];
+ }
+ splits_tensor_flat(0) = 0;
+ for (int i = 0; i < splits.size(); ++i) {
+ splits_tensor_flat(i + 1) = splits[i];
+ }
+ }
+};
+
+} // namespace ops
+} // namespace tensorflow
+REGISTER_KERNEL_BUILDER(
+ Name("TFSentencepieceTokenizeOp").Device(tensorflow::DEVICE_CPU),
+ tensorflow::ops::TFSentencepieceOp);
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
new file mode 100644
index 00000000..8309a6a2
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc
@@ -0,0 +1,129 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/**
+ * Sentencepiece tflite tokenizer implementation.
+ */
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h"
+#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h"
+#include "flatbuffers/flexbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+namespace tokenizer {
+
+constexpr int kOutputValuesInd = 0;
+constexpr int kOutputSplitsInd = 1;
+
+namespace {
+TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
+ TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
+ int index = 0;
+ for (const int size : sizes) {
+ array_size->data[index++] = size;
+ }
+ return array_size;
+}
+} // namespace
+
+// This op keeps no per-node state, so there is nothing to initialize;
+// returns nullptr.
+void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
+ size_t /*length*/) {
+ return nullptr;
+}
+void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ // TODO(mgubin): Add checks for input and output tensors.
+ TfLiteTensor& output_values =
+ context->tensors[node->outputs->data[kOutputValuesInd]];
+ SetTensorToDynamic(&output_values);
+
+ TfLiteTensor& output_splits =
+ context->tensors[node->outputs->data[kOutputSplitsInd]];
+ SetTensorToDynamic(&output_splits);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor& model_tensor =
+ context->tensors[node->inputs->data[tensorflow::ops::kSPModelIndex]];
+ const auto model_buffer_data = model_tensor.data.data;
+ const TfLiteTensor& input_text =
+ context->tensors[node->inputs->data[tensorflow::ops::kInputIndex]];
+
+  const TfLiteTensor& add_bos_tensor =
+      context->tensors[node->inputs->data[tensorflow::ops::kAddBOSInput]];
+  const bool add_bos = add_bos_tensor.data.b[0];
+  const TfLiteTensor& add_eos_tensor =
+      context->tensors[node->inputs->data[tensorflow::ops::kAddEOSInput]];
+  const bool add_eos = add_eos_tensor.data.b[0];
+  const TfLiteTensor& reverse_tensor =
+      context->tensors[node->inputs->data[tensorflow::ops::kReverseInput]];
+  const bool reverse = reverse_tensor.data.b[0];
+
+ std::vector<int32> encoded;
+ std::vector<int32> splits;
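+  // Each entry of `splits` is the running end offset of a string's codes;
+  // e.g. two inputs producing 2 and 3 ids give splits = [2, 5], written out
+  // below as output_splits = [0, 2, 5].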
+ const int num_strings = tflite::GetStringCount(&input_text);
+ for (int i = 0; i < num_strings; ++i) {
+ const auto strref = tflite::GetString(&input_text, i);
+ const auto res = EncodeString(std::string(strref.str, strref.len),
+ model_buffer_data, add_bos, add_eos, reverse);
+ TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
+ "Sentencepiece conversion failed");
+ std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
+ splits.emplace_back(encoded.size());
+ }
+
+ TfLiteTensor& output_values =
+ context->tensors[node->outputs->data[kOutputValuesInd]];
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(
+ context, &output_values,
+ CreateSizeArray({static_cast<int>(encoded.size())})));
+ int32_t* output_values_flat = output_values.data.i32;
+ std::copy(encoded.begin(), encoded.end(), output_values_flat);
+ TfLiteTensor& output_splits =
+ context->tensors[node->outputs->data[kOutputSplitsInd]];
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(
+ context, &output_splits,
+ CreateSizeArray({static_cast<int>(splits.size() + 1)})));
+ int32_t* output_splits_flat = output_splits.data.i32;
+ *output_splits_flat = 0;
+ std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
+ return kTfLiteOk;
+}
+} // namespace tokenizer
+} // namespace sentencepiece
+
+TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
+ static TfLiteRegistration r = {
+ sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
+ sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/testdata/sentencepiece.model b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/testdata/sentencepiece.model
new file mode 100644
index 00000000..041188ff
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/testdata/sentencepiece.model
Binary files differ
diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h
new file mode 100644
index 00000000..13bc021e
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h
@@ -0,0 +1,66 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_UTILS_H_
+
+#include <cstring>
+#include <ostream>
+#include <string>
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sentencepiece {
+
+// AOSP and WASM don't support string_view, so a minimal re-implementation
+// is provided here.
+namespace utils {
+
+class string_view {
+ public:
+ explicit string_view(const std::string& s)
+ : str_(s.data()), len_(s.length()) {}
+ string_view(const char* str, int len) : str_(str), len_(len) {}
+ // A constructor from c string.
+ explicit string_view(const char* s) : str_(s), len_(strlen(s)) {}
+
+ int length() const { return len_; }
+ const char* data() const { return str_; }
+ bool empty() const { return len_ == 0; }
+ unsigned char at(int i) const { return str_[i]; }
+
+ private:
+ const char* str_ = nullptr;
+ const int len_ = 0;
+};
+
+inline std::ostream& operator<<(std::ostream& os, const string_view& sv) {
+ os << std::string(sv.data(), sv.length());
+ return os;
+}
+inline bool operator==(const string_view& view1, const string_view& view2) {
+ if (view1.length() != view2.length()) {
+ return false;
+ }
+ return memcmp(view1.data(), view2.data(), view1.length()) == 0;
+}
+
+} // namespace utils
+} // namespace sentencepiece
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_UTILS_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_flex_delegate.tflite b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_flex_delegate.tflite
new file mode 100644
index 00000000..dc3b78b2
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_flex_delegate.tflite
Binary files differ
diff --git a/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_1d_input.tflite b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_1d_input.tflite
new file mode 100644
index 00000000..03640e28
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_1d_input.tflite
Binary files differ
diff --git a/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_2d_input.tflite b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_2d_input.tflite
new file mode 100644
index 00000000..b6883745
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_2d_input.tflite
Binary files differ
diff --git a/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_tensor.tflite b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_tensor.tflite
new file mode 100644
index 00000000..88e5cef5
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_tensor.tflite
Binary files differ
diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
new file mode 100644
index 00000000..dad2f000
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc
@@ -0,0 +1,224 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h"
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/lite/context.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/string_util.h"
+#include "libutf/utf.h"
+
+constexpr int kInput = 0;
+constexpr int kOutputValues = 0;
+constexpr int kOutputRowSplitsStart = 1;
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace whitespace_tokenizer {
+
+// This TFLite op implements a whitespace tokenizer, and can output the
+// tokens as either a padded tensor or a ragged tensor.
+//
+// If we're outputting a padded tensor, our outputs are:
+// * A string tensor
+//
+// If we're outputting a ragged tensor, our outputs are:
+// * A string tensor (the innermost values of the ragged tensor)
+// * N int64 tensors (the row_splits of the ragged tensor, where N is the
+// rank of the input tensor)
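+//
+// For example (illustrative), tokenizing the input ["a b c", "d"] of shape
+// [2] yields:
+//   * padded: a [2, 3] string tensor {{"a", "b", "c"}, {"d", "", ""}}
+//   * ragged: values = ["a", "b", "c", "d"] and row_splits = [0, 3, 4]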
+
+inline bool OutputIsPaddedTensor(TfLiteNode* node) {
+ return NumOutputs(node) == 1;
+}
+
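+// A bounds-checked variant of libutf's chartorune: decodes at most `n` bytes
+// and reports Runeerror when the encoded rune would overrun the buffer.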
+inline int charntorune(Rune* r, const char* s, int n) {
+ const int bytes_read = chartorune(r, const_cast<char *>(s));
+ if (bytes_read > n) {
+ *r = Runeerror;
+ return 0;
+ }
+ return bytes_read;
+}
+
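+// Splits `str` on Unicode whitespace, returning (pointer, length) pairs that
+// point into the original buffer (no copies are made); e.g. "ab  c" yields
+// {("ab", 2), ("c", 1)}.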
+std::vector<std::pair<const char*, int>> Tokenize(StringRef str) {
+ const char* p = str.str;
+ int n = str.len;
+
+ std::vector<std::pair<const char*, int>> tokens;
+ const char* start = nullptr;
+ while (n > 0) {
+ Rune r;
+ int c = charntorune(&r, p, n);
+ if (r == Runeerror) break;
+
+ if (isspacerune(r)) {
+ if (start != nullptr) {
+ tokens.push_back({start, p - start});
+ }
+ start = nullptr;
+ } else {
+ if (start == nullptr) {
+ start = p;
+ }
+ }
+
+ p += c;
+ n -= c;
+ }
+ if (start != nullptr) {
+ tokens.push_back({start, p - start});
+ }
+
+ return tokens;
+}
+
+TfLiteStatus WritePaddedOutput(
+ const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens,
+ const TfLiteTensor* input, TfLiteTensor* output_values) {
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) + 1);
+ for (int i = 0; i < NumDimensions(input); ++i) {
+ output_shape->data[i] = SizeOfDimension(input, i);
+ }
+
+ size_t max_tokens = 0;
+ for (const auto& tokens : list_of_tokens) {
+ max_tokens = std::max(max_tokens, tokens.size());
+ }
+
+ output_shape->data[NumDimensions(input)] = max_tokens;
+ DynamicBuffer buffer;
+ for (const auto& tokens : list_of_tokens) {
+ for (const auto& token : tokens) {
+ buffer.AddString(token.first, token.second);
+ }
+ for (int i = tokens.size(); i < max_tokens; ++i) {
+ buffer.AddString(nullptr, 0);
+ }
+ }
+ buffer.WriteToTensor(output_values, output_shape);
+ return kTfLiteOk;
+}
+
+TfLiteStatus WriteRaggedOutput(
+ const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens,
+ const TfLiteTensor* input, TfLiteTensor* output_values,
+ std::vector<TfLiteTensor*> nested_row_splits) {
+ // The outer dimensions of the ragged tensor are all non-ragged.
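+  // For example, an input of shape [3, 2] gives a first row_splits tensor of
+  // [0, 2, 4, 6]: each outer row spans exactly two inner rows.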
+ for (int i = 0; i < nested_row_splits.size() - 1; ++i) {
+ int row_splits_step = SizeOfDimension(input, i + 1);
+ TfLiteTensor* row_splits = nested_row_splits[i];
+ for (int j = 0; j < SizeOfDimension(row_splits, 0); ++j) {
+ row_splits->data.i64[j] = j * row_splits_step;
+ }
+ }
+
+ // Generate the innermost row_splits and values tensors.
+ TfLiteTensor* row_splits = nested_row_splits.back();
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(1);
+ DynamicBuffer buffer;
+ int token_index = 0;
+ int row_splits_index = 0;
+ for (const auto& tokens : list_of_tokens) {
+ row_splits->data.i64[row_splits_index] = token_index;
+ for (const auto& token : tokens) {
+ buffer.AddString(token.first, token.second);
+ ++token_index;
+ }
+ ++row_splits_index;
+ }
+ row_splits->data.i64[row_splits_index] = token_index;
+ output_shape->data[0] = token_index;
+ buffer.WriteToTensor(output_values, output_shape);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
+ SetTensorToDynamic(output_values);
+
+ if (OutputIsPaddedTensor(node)) {
+ return kTfLiteOk;
+ }
+
+ const TfLiteTensor* input = GetInput(context, node, kInput);
+ TF_LITE_ENSURE(context, NumDimensions(input) ==
+ (NumOutputs(node) - kOutputRowSplitsStart));
+
+ // Resize the row_splits tensors. We're just adding a ragged inner
+ // dimension to the shape of the input tensor, so the size of the
+ // row_splits tensors can be calculated using the input tensor's shape.
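+  // For example, an input of shape [3, 2] needs row_splits tensors with
+  // 3 + 1 = 4 and 3 * 2 + 1 = 7 elements, respectively.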
+ int input_size = 1;
+ for (int i = 0; i < NumDimensions(input); ++i) {
+ input_size *= SizeOfDimension(input, i);
+
+ TfLiteIntArray* row_splits_shape = TfLiteIntArrayCreate(1);
+ row_splits_shape->data[0] = input_size + 1;
+ TfLiteTensor* row_splits =
+ GetOutput(context, node, kOutputRowSplitsStart + i);
+ TF_LITE_ENSURE_STATUS(
+ context->ResizeTensor(context, row_splits, row_splits_shape));
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInput);
+ int input_size = 1;
+ for (int i = 0; i < NumDimensions(input); ++i) {
+ input_size *= SizeOfDimension(input, i);
+ }
+
+ std::vector<std::vector<std::pair<const char*, int>>> list_of_tokens;
+ list_of_tokens.reserve(input_size);
+ for (int i = 0; i < input_size; ++i) {
+ list_of_tokens.emplace_back(Tokenize(GetString(input, i)));
+ }
+
+ TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
+ TF_LITE_ENSURE(context, IsDynamicTensor(output_values));
+
+ if (OutputIsPaddedTensor(node)) {
+ return WritePaddedOutput(list_of_tokens, input, output_values);
+ }
+
+ std::vector<TfLiteTensor*> nested_row_splits;
+ nested_row_splits.reserve(NumDimensions(input));
+ for (int i = 0; i < NumDimensions(input); ++i) {
+ TfLiteTensor* output_row_splits =
+ GetOutput(context, node, kOutputRowSplitsStart + i);
+ nested_row_splits.push_back(output_row_splits);
+ }
+ return WriteRaggedOutput(list_of_tokens, input, output_values,
+ nested_row_splits);
+}
+
+} // namespace whitespace_tokenizer
+
+TfLiteRegistration* Register_tftext_WhitespaceTokenizer() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ whitespace_tokenizer::Prepare,
+ whitespace_tokenizer::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h
new file mode 100644
index 00000000..b1902480
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h
@@ -0,0 +1,31 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_H_
+
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_tftext_WhitespaceTokenizer();
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
new file mode 100644
index 00000000..534fbef4
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h"
+
+#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+void AddWhitespaceTokenizerCustomOp(MutableOpResolver* resolver) {
+ resolver->AddCustom("tftext:WhitespaceTokenizer",
+ Register_tftext_WhitespaceTokenizer());
+}
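+
+// Example usage (illustrative sketch; `model` is an already-loaded
+// tflite::FlatBufferModel):
+//
+//   tflite::MutableOpResolver resolver;
+//   tflite::ops::custom::AddWhitespaceTokenizerCustomOp(&resolver);
+//   std::unique_ptr<tflite::Interpreter> interpreter;
+//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);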
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h
new file mode 100644
index 00000000..4f57d8d8
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_OP_RESOLVER_H_
+#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_OP_RESOLVER_H_
+
+#include "tensorflow/lite/mutable_op_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+// Adds the WhitespaceTokenizer custom op to an op resolver.
+// This function can be looked up with dlopen. Since C++ function names get
+// mangled, it is declared extern "C" so its symbol name stays unmangled.
+extern "C" void AddWhitespaceTokenizerCustomOp(MutableOpResolver* resolver);
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_OP_RESOLVER_H_
diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver_wrapper.cc b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver_wrapper.cc
new file mode 100644
index 00000000..03d3ba89
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver_wrapper.cc
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/lite/mutable_op_resolver.h"
+#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h"
+
+PYBIND11_MODULE(_pywrap_whitespace_tokenizer_op_resolver, m) {
+ m.doc() = "_pywrap_whitespace_tokenizer_op_resolver";
+ m.def(
+ "AddWhitespaceTokenizerCustomOp",
+ [](uintptr_t resolver) {
+ tflite::ops::custom::AddWhitespaceTokenizerCustomOp(
+ reinterpret_cast<tflite::MutableOpResolver*>(resolver));
+ },
+ "Op registerer function for the tftext:WhitespaceTokenizer custom op.");
+}
diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.cc b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.cc
new file mode 100644
index 00000000..4654e46c
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.cc
@@ -0,0 +1,189 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h"
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace whitespace_tokenizer {
+namespace test {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+} // namespace
+
+enum OutputType { PADDED, RAGGED };
+
+class WhitespaceTokenizerModel : public SingleOpModel {
+ public:
+ WhitespaceTokenizerModel(OutputType output_type,
+ const std::vector<std::string>& input_values,
+ const std::vector<int>& input_shape)
+ : input_shape_(input_shape) {
+ input_ = AddInput(TensorType_STRING);
+ output_values_ = AddOutput(TensorType_STRING);
+ if (output_type == RAGGED) {
+ for (int i = 0; i < input_shape_.size(); ++i) {
+ output_row_splits_.push_back(AddOutput(TensorType_INT64));
+ }
+ }
+ SetCustomOp("WhitespaceTokenizer", {}, Register_tftext_WhitespaceTokenizer);
+
+ BuildInterpreter({input_shape});
+ PopulateStringTensor(input_, input_values);
+ Invoke();
+ }
+
+ std::vector<int> GetValuesTensorShape() {
+ return GetTensorShape(output_values_);
+ }
+
+ std::vector<std::string> ExtractValuesTensorVector() {
+ std::vector<std::string> r;
+ TfLiteTensor* tensor = interpreter_->tensor(output_values_);
+ int n = GetStringCount(tensor);
+ for (int i = 0; i < n; ++i) {
+ StringRef ref = GetString(tensor, i);
+ r.emplace_back(ref.str, ref.len);
+ }
+ return r;
+ }
+
+ void CheckRowSplits(const std::vector<int>& token_counts) {
+ int size = 1;
+ for (int i = 0; i < input_shape_.size(); ++i) {
+ size *= input_shape_[i];
+ EXPECT_THAT(GetTensorShape(output_row_splits_[i]), ElementsAre(size + 1))
+ << "row_splits " << i << " has the wrong shape";
+
+ std::vector<int64_t> expected_values(size + 1);
+ if (i == input_shape_.size() - 1) {
+ ASSERT_EQ(token_counts.size(), size);
+
+ int index = 0;
+ expected_values[0] = index;
+ for (int j = 0; j < size; ++j) {
+ index += token_counts[j];
+ expected_values[j + 1] = index;
+ }
+ } else {
+ for (int j = 0; j <= size; ++j) {
+ expected_values[j] = j * input_shape_[i + 1];
+ }
+ }
+ EXPECT_THAT(ExtractVector<int64_t>(output_row_splits_[i]),
+ ElementsAreArray(expected_values))
+ << "row_splits " << i << " has an incorrect value/index";
+ }
+ }
+
+ private:
+ int input_;
+ std::vector<int> input_shape_;
+ int output_values_;
+ std::vector<int> output_row_splits_;
+};
+
+TEST(WhitespaceTokenizerTest, SingleStringPaddedOutput) {
+ WhitespaceTokenizerModel m(PADDED, {"this is a test"}, {1});
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(1, 4));
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("this", "is", "a", "test"));
+}
+
+TEST(WhitespaceTokenizerTest, SingleStringRaggedOutput) {
+ WhitespaceTokenizerModel m(RAGGED, {"this is a test"}, {1});
+ m.CheckRowSplits({4});
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("this", "is", "a", "test"));
+}
+
+TEST(WhitespaceTokenizerTest, VectorPaddedOutput) {
+ WhitespaceTokenizerModel m(PADDED,
+ {"this is a test", //
+ "three token sentence", //
+ "many more tokens than that sentence"},
+ {3});
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3, 6));
+ EXPECT_THAT(
+ m.ExtractValuesTensorVector(),
+ ElementsAre("this", "is", "a", "test", "", "", //
+ "three", "token", "sentence", "", "", "", //
+ "many", "more", "tokens", "than", "that", "sentence"));
+}
+
+TEST(WhitespaceTokenizerTest, VectorRaggedOutput) {
+ WhitespaceTokenizerModel m(RAGGED,
+ {"this is a test", //
+ "three token sentence", //
+ "many more tokens than that sentence"},
+ {3});
+ m.CheckRowSplits({4, 3, 6});
+ EXPECT_THAT(
+ m.ExtractValuesTensorVector(),
+ ElementsAre("this", "is", "a", "test", //
+ "three", "token", "sentence", //
+ "many", "more", "tokens", "than", "that", "sentence"));
+}
+
+TEST(WhitespaceTokenizerTest, MatrixPaddedOutput) {
+ WhitespaceTokenizerModel m(PADDED,
+ {"a b c", "d e f", //
+ "g h", "i j k l", //
+ "m", "n o p q r"},
+ {3, 2});
+ EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3, 2, 5));
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("a", "b", "c", "", "", //
+ "d", "e", "f", "", "", //
+ "g", "h", "", "", "", //
+ "i", "j", "k", "l", "", //
+ "m", "", "", "", "", //
+ "n", "o", "p", "q", "r"));
+}
+
+TEST(WhitespaceTokenizerTest, MatrixRaggedOutput) {
+ WhitespaceTokenizerModel m(RAGGED,
+ {"a b c", "d e f", //
+ "g h", "i j k l", //
+ "m", "n o p q r"},
+ {3, 2});
+ m.CheckRowSplits({3, 3, 2, 4, 1, 5});
+ EXPECT_THAT(m.ExtractValuesTensorVector(),
+ ElementsAre("a", "b", "c", //
+ "d", "e", "f", //
+ "g", "h", //
+ "i", "j", "k", "l", //
+ "m", //
+ "n", "o", "p", "q", "r"));
+}
+
+} // namespace test
+} // namespace whitespace_tokenizer
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py
new file mode 100644
index 00000000..b6a1a67d
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py
@@ -0,0 +1,168 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Lint as: python3
+"""Tests for tensorflow_lite_support.custom_ops.kernel.whitespace_tokenizer."""
+
+import os
+import sys
+import timeit
+
+from absl import logging
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+import tensorflow_text as tf_text
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.lite.python import interpreter as interpreter_wrapper
+from tensorflow.python.platform import resource_loader
+
+# Force loaded shared object symbols to be globally visible. This is needed so
+# that the interpreter_wrapper, in one .so file, can see the op resolver
+# in a different .so file. Note that this may already be set by default.
+# pylint: disable=g-import-not-at-top,g-bad-import-order,unused-import
+if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'):
+ sys.setdlopenflags(sys.getdlopenflags() | os.RTLD_GLOBAL)
+from tensorflow_lite_support.custom_ops.kernel import _pywrap_whitespace_tokenizer_op_resolver
+
+TEST_CASES = [
+ ['this is a test'],
+ ['extra spaces in here'],
+ ['a four token sentence', 'a five token sentence thing.'],
+ [['a multi dimensional test case', 'a b c d', 'e f g'],
+ ['h i j', 'k l m 2 3', 'n o p'], ['q r s 0 1', 't u v', 'w x y z']],
+]
+
+INVOKES_FOR_SINGLE_OP_BENCHMARK = 1000
+INVOKES_FOR_FLEX_DELEGATE_BENCHMARK = 10
+
+
+@tf.function
+def _call_whitespace_tokenizer_to_tensor(test_case):
+ tokenizer = tf_text.WhitespaceTokenizer()
+ return tokenizer.tokenize(test_case).to_tensor()
+
+
+@tf.function
+def _call_whitespace_tokenizer_to_ragged(test_case):
+ tokenizer = tf_text.WhitespaceTokenizer()
+ return tokenizer.tokenize(test_case)
+
+
+class WhitespaceTokenizerTest(parameterized.TestCase):
+
+ @parameterized.parameters([t] for t in TEST_CASES)
+ def testToTensorEquivalence(self, test_case):
+ tf_output = _call_whitespace_tokenizer_to_tensor(test_case)
+
+ model_filename = resource_loader.get_path_to_datafile(
+ 'testdata/whitespace_tokenizer_to_tensor.tflite')
+ with open(model_filename, 'rb') as file:
+ model = file.read()
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=model,
+ custom_op_registerers=['AddWhitespaceTokenizerCustomOp'])
+
+ np_test_case = np.array(test_case, dtype=np.str)
+ interpreter.resize_tensor_input(0, np_test_case.shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ np_test_case)
+ interpreter.invoke()
+ tflite_output = interpreter.get_tensor(
+ interpreter.get_output_details()[0]['index'])
+
+ self.assertEqual(tf_output.numpy().tolist(), tflite_output.tolist())
+
+ @parameterized.parameters([t] for t in TEST_CASES)
+ def testToRaggedEquivalence(self, test_case):
+ tf_output = _call_whitespace_tokenizer_to_ragged(test_case)
+
+ np_test_case = np.array(test_case, dtype=np.str)
+ rank = len(np_test_case.shape)
+
+ model_filename = resource_loader.get_path_to_datafile(
+ 'testdata/whitespace_tokenizer_to_ragged_{}d_input.tflite'.format(rank))
+ with open(model_filename, 'rb') as file:
+ model = file.read()
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=model,
+ custom_op_registerers=['AddWhitespaceTokenizerCustomOp'])
+ interpreter.resize_tensor_input(0, np_test_case.shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ np_test_case)
+ interpreter.invoke()
+
+ # Traverse the nested row_splits/values of the ragged tensor.
+ for i in range(rank):
+ tflite_output_cur_row_splits = interpreter.get_tensor(
+ interpreter.get_output_details()[1 + i]['index'])
+ self.assertEqual(tf_output.row_splits.numpy().tolist(),
+ tflite_output_cur_row_splits.tolist())
+ tf_output = tf_output.values
+
+ tflite_output_values = interpreter.get_tensor(
+ interpreter.get_output_details()[0]['index'])
+ self.assertEqual(tf_output.numpy().tolist(), tflite_output_values.tolist())
+
+ def testSingleOpLatency(self):
+ model_filename = resource_loader.get_path_to_datafile(
+ 'testdata/whitespace_tokenizer_to_tensor.tflite')
+ with open(model_filename, 'rb') as file:
+ model = file.read()
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=model,
+ custom_op_registerers=['AddWhitespaceTokenizerCustomOp'])
+
+ latency = 0.0
+ for test_case in TEST_CASES:
+ np_test_case = np.array(test_case, dtype=np.str)
+ interpreter.resize_tensor_input(0, np_test_case.shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ np_test_case)
+ start_time = timeit.default_timer()
+ for _ in range(INVOKES_FOR_SINGLE_OP_BENCHMARK):
+ interpreter.invoke()
+ latency = latency + timeit.default_timer() - start_time
+
+ latency = latency / (INVOKES_FOR_SINGLE_OP_BENCHMARK * len(TEST_CASES))
+ logging.info('Latency: %fms', latency * 1000.0)
+
+ def testFlexDelegateLatency(self):
+ model_filename = resource_loader.get_path_to_datafile(
+ 'testdata/whitespace_tokenizer_flex_delegate.tflite')
+ with open(model_filename, 'rb') as file:
+ model = file.read()
+ interpreter = interpreter_wrapper.Interpreter(model_content=model)
+
+ latency = 0.0
+ for test_case in TEST_CASES:
+ np_test_case = np.array(test_case, dtype=np.str)
+ interpreter.resize_tensor_input(0, np_test_case.shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
+ np_test_case)
+ start_time = timeit.default_timer()
+ for _ in range(INVOKES_FOR_FLEX_DELEGATE_BENCHMARK):
+ interpreter.invoke()
+ latency = latency + timeit.default_timer() - start_time
+
+ latency = latency / (INVOKES_FOR_FLEX_DELEGATE_BENCHMARK * len(TEST_CASES))
+ logging.info('Latency: %fms', latency * 1000.0)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow_lite_support/custom_ops/python/BUILD b/tensorflow_lite_support/custom_ops/python/BUILD
new file mode 100644
index 00000000..82a8a6ec
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/python/BUILD
@@ -0,0 +1,61 @@
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+py_library(
+ name = "tflite_text_api",
+ srcs = ["tflite_text_api.py"],
+ deps = [
+ # tensorflow dep,
+ # tensorflow_text dep,
+ ],
+)
+
+py_library(
+ name = "sentencepiece_tokenizer",
+ srcs = ["sentencepiece_tokenizer.py"],
+ data = [
+ "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:sentencepiece_detokenizer_op.so",
+ "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:sentencepiece_tokenizer_op.so",
+ ],
+ srcs_version = "PY3",
+ deps = [
+ # tensorflow dep,
+ "//tensorflow_lite_support/custom_ops/kernel/sentencepiece/py:pywrap_model_converter",
+ ],
+)
+
+py_test(
+ name = "sentencepiece_tokenizer_test",
+ srcs = ["sentencepiece_tokenizer_test.py"],
+ data = [
+ "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:testdata",
+ ],
+ python_version = "PY3",
+ deps = [
+ ":sentencepiece_tokenizer",
+ # tensorflow dep,
+ # tensorflow_text dep,
+ "//tensorflow_lite_support/custom_ops/kernel/sentencepiece/py:pywrap_tflite_registerer",
+ "@absl_py//absl:app",
+ "@absl_py//absl/flags",
+ "@absl_py//absl/logging",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "ragged_tensor_to_tensor_test",
+ srcs = ["ragged_tensor_to_tensor_test.py"],
+ python_version = "PY3",
+ deps = [
+ # tensorflow dep,
+ "//tensorflow_lite_support/custom_ops/kernel/ragged/py:pywrap_tflite_registerer",
+ "@absl_py//absl:app",
+ "@absl_py//absl/flags",
+ "@absl_py//absl/logging",
+ ],
+)
diff --git a/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py b/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py
new file mode 100644
index 00000000..319131e0
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py
@@ -0,0 +1,57 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for ragged_tensor_to_tensor."""
+
+import tensorflow as tf
+from tensorflow.lite.python import interpreter as interpreter_wrapper # pylint: disable=g-direct-tensorflow-import
+
+
+class RaggedTensorToTensorTest(tf.test.TestCase):
+
+ def test_ragged_to_tensor(self):
+
+ @tf.function
+ def ragged_tensor_function():
+ ragged_tensor = tf.RaggedTensor.from_row_splits(
+ values=[
+ 13, 36, 83, 131, 13, 36, 4, 3127, 152, 130, 30, 2424, 168, 1644,
+ 1524, 4, 3127, 152, 130, 30, 2424, 168, 1644, 636
+ ],
+ row_splits=[0, 0, 6, 15, 24])
+ return ragged_tensor.to_tensor()
+
+ concrete_function = ragged_tensor_function.get_concrete_function()
+
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function])
+ converter.allow_custom_ops = True
+ tflite_model = converter.convert()
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=tflite_model,
+ custom_op_registerers=["TFLite_RaggedTensorToTensorRegisterer"])
+ interpreter.allocate_tensors()
+ interpreter.invoke()
+ output_details = interpreter.get_output_details()
+ expected_result_values = [[0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [13, 36, 83, 131, 13, 36, 0, 0, 0],
+ [4, 3127, 152, 130, 30, 2424, 168, 1644, 1524],
+ [4, 3127, 152, 130, 30, 2424, 168, 1644, 636]]
+ self.assertAllEqual(
+ interpreter.get_tensor(output_details[0]["index"]),
+ expected_result_values)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py
new file mode 100644
index 00000000..21efed56
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py
@@ -0,0 +1,125 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Lint as: python3
+"""Python class that implements Sentencepiece tokenizer.
+
+It follows TF.text designers design.
+
+"""
+import tensorflow.compat.v2 as tf # pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.ops.ragged import ragged_tensor # pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.framework import load_library
+from tensorflow.python.platform import resource_loader
+
+gen_sentencepiece_detokenizer_op = load_library.load_op_library(
+    resource_loader.get_path_to_datafile(
+        '../kernel/sentencepiece/sentencepiece_detokenizer_op.so'))
+gen_sentencepiece_tokenizer_op = load_library.load_op_library(
+    resource_loader.get_path_to_datafile(
+        '../kernel/sentencepiece/sentencepiece_tokenizer_op.so'))
+from tensorflow_lite_support.custom_ops.kernel.sentencepiece.py import pywrap_model_converter as model_converter
+
+
+class SentencepieceTokenizer:
+ """Sentencepiece tokenizer with tf.text interface."""
+
+ def __init__(self, model, reverse=False, add_bos=False, add_eos=False):
+ converted_model = model_converter.convert_sentencepiece_model(model)
+ converted_model_detokenizer = model_converter.convert_sentencepiece_model_for_decoder(
+ model)
+    # Use a uint8 tensor as a buffer for the model to avoid any possible
+    # modification, for example truncation at '\0'.
+ self._converted_model = tf.constant(list(converted_model), dtype=tf.uint8)
+ self._converted_model_detokenizer = tf.constant(
+ list(converted_model_detokenizer), dtype=tf.uint8)
+ self._vocab_size = model_converter.get_vocabulary_size(converted_model)
+ self._reverse = reverse
+ self._add_bos = add_bos
+ self._add_eos = add_eos
+
+ def tokenize(self, inputs):
+ """The main tokenization function."""
+ input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(inputs)
+ if input_tensor.shape.ndims is None:
+ raise ValueError("Rank of input_tensor must be statically known.")
+ if ragged_tensor.is_ragged(input_tensor):
+      # Ensure the input's row_splits dtype is int32.
+ input_tensor = input_tensor.with_row_splits_dtype(tf.int32)
+ # Recursively process the values of the ragged tensor.
+ tokens = self.tokenize(input_tensor.flat_values)
+ return input_tensor.with_flat_values(tokens)
+ else:
+ if input_tensor.shape.ndims > 1:
+ # Convert the input tensor to ragged and process it.
+ return self.tokenize(
+ tf.RaggedTensor.from_tensor(
+ input_tensor, row_splits_dtype=tf.int32))
+ elif input_tensor.shape.ndims == 0:
+ tokens = self.tokenize(tf.stack([input_tensor]))
+ return tokens.values
+ else:
+ # Our rank 1 tensor is the correct shape, so we can process it as
+ # normal.
+ (output_values, row_splits) = (
+ gen_sentencepiece_tokenizer_op.tf_sentencepiece_tokenize_op(
+ self._converted_model, input_tensor, 0, 0, self._add_bos,
+ self._add_eos, self._reverse))
+ tokens = tf.RaggedTensor.from_nested_row_splits(
+ flat_values=output_values,
+ nested_row_splits=[row_splits],
+ validate=False)
+ return tokens
+
+ def detokenize(self, input): # pylint: disable=redefined-builtin
+ """Detokenizes tokens into preprocessed text.
+
+ Args:
+      input: A `RaggedTensor` or `Tensor` of int32-encoded text, with rank
+        >= 1.
+
+ Returns:
+      An N-1 dimensional string Tensor or RaggedTensor of the detokenized
+      text.
+ """
+ input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
+ if input_tensor.shape.ndims is None:
+ raise ValueError("Rank of input_tensor must be statically known.")
+ if input_tensor.shape.ndims == 0:
+ raise ValueError("Rank of input_tensor must be at least 1.")
+ if ragged_tensor.is_ragged(input_tensor):
+ if input_tensor.flat_values.shape.ndims > 1:
+ # If the flat_values of our ragged tensor is multi-dimensional, we can
+ # process it separately and our output will have the same nested
+ # splits as our input.
+ tokens = self.detokenize(input_tensor.flat_values)
+ return input_tensor.with_flat_values(tokens)
+ elif input_tensor.ragged_rank > 1:
+ # Recursively process the values of the ragged tensor.
+ tokens = self.detokenize(input_tensor.values)
+ return input_tensor.with_values(tokens)
+ else:
+ return gen_sentencepiece_detokenizer_op.tf_sentencepiece_detokenize_op(
+ self._converted_model_detokenizer, input_tensor.flat_values,
+ input_tensor.row_splits)
+ else:
+ if input_tensor.shape.ndims > 1:
+ # Convert the input tensor to ragged and process it.
+ return self.detokenize(
+ tf.RaggedTensor.from_tensor(
+ input_tensor, row_splits_dtype=tf.int32))
+ else:
+ tokens = self.detokenize(tf.stack([input_tensor]))
+ return tf.reshape(tokens, [])
+
+ def vocab_size(self):
+ """Returns size of the vocabulary in Sentencepiece model."""
+ return self._vocab_size
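+
+
+# Example usage (illustrative sketch; `model` holds the serialized bytes of a
+# sentencepiece.model file):
+#
+#   tokenizer = SentencepieceTokenizer(model, add_bos=True, add_eos=True)
+#   tokens = tokenizer.tokenize(["to be or not to be"])  # tf.RaggedTensor
+#   text = tokenizer.detokenize(tokens)  # string tf.Tensor of shape [1]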
diff --git a/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py
new file mode 100644
index 00000000..3609b469
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py
@@ -0,0 +1,251 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Lint as: python3
+"""Tests for sentencepiece_tokenizer."""
+
+import os
+import sys
+import time
+
+from absl import flags
+import numpy as np
+import tensorflow.compat.v2 as tf # pylint: disable=g-direct-tensorflow-import
+import tensorflow_text
+# Force loaded shared object symbols to be globally visible. This is needed so
+# that the interpreter_wrapper, in one .so file, can see the op resolver
+# in a different .so file. Note that this may already be set by default.
+# pylint: disable=g-import-not-at-top,g-bad-import-order,unused-import
+if hasattr(sys, "setdlopenflags") and hasattr(sys, "getdlopenflags"):
+ sys.setdlopenflags(sys.getdlopenflags() | os.RTLD_GLOBAL)
+from tensorflow.lite.python import interpreter as interpreter_wrapper # pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.platform import resource_loader
+from tensorflow_lite_support.custom_ops.python import sentencepiece_tokenizer
+from tensorflow_lite_support.custom_ops.kernel.sentencepiece.py import pywrap_tflite_registerer
+
+FLAGS = flags.FLAGS
+
+SENTENCEPIECE_MODEL_FILE = (
+ "../kernel/sentencepiece/testdata/sentencepiece.model")
+
+
+def _GetSentencepieceModel():
+ model_filename = resource_loader.get_path_to_datafile(
+ SENTENCEPIECE_MODEL_FILE)
+ with open(model_filename, "rb") as file:
+ model = file.read()
+ return model
+
+
+class SentencepieceTokenizerTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(SentencepieceTokenizerTest, self).setUp()
+ self.sentencepiece_model = _GetSentencepieceModel()
+
+ def test_tftext_sentencepiece_tokenizer(self):
+ """Check that the new tokenizer produces the same result that the tftext one."""
+ tftext_sp = tensorflow_text.SentencepieceTokenizer(self.sentencepiece_model)
+ opt_sp = sentencepiece_tokenizer.SentencepieceTokenizer(
+ self.sentencepiece_model)
+
+ input_text = [
+ u" ", u"to be or not to be", u"ignored by length text1",
+ u"ignored by length text2"
+ ]
+ tftext_tokenized = tftext_sp.tokenize(input_text)
+ opt_tokenized = opt_sp.tokenize(input_text)
+ self.assertAllEqual(tftext_tokenized, opt_tokenized)
+
+ def test_tftext_sentencepiece_detokenizer(self):
+ """Check that the new tokenizer produces the same result that the tftext one."""
+ tftext_sp = tensorflow_text.SentencepieceTokenizer(self.sentencepiece_model)
+ opt_sp = sentencepiece_tokenizer.SentencepieceTokenizer(
+ self.sentencepiece_model)
+
+ input_text = [
+ u" ", u"to be or not to be", u"ignored by length text1",
+ u"ignored by length text2"
+ ]
+ tftext_tokenized = tftext_sp.tokenize(input_text)
+
+ # Check detokenizer
+ tftext_detokenized = tftext_sp.detokenize(tftext_tokenized)
+ opt_detokenized = opt_sp.detokenize(tftext_tokenized)
+ self.assertAllEqual(tftext_detokenized, opt_detokenized)
+
+ def test_tftext_sentencepiece_tokenizer_bos_eos(self):
+ """Check that the new tokenizer produces the same result that the tftext one with bos and eos."""
+ tftext_sp = tensorflow_text.SentencepieceTokenizer(
+ self.sentencepiece_model, add_bos=True, add_eos=True)
+ opt_sp = sentencepiece_tokenizer.SentencepieceTokenizer(
+ self.sentencepiece_model, add_bos=True, add_eos=True)
+
+ input_text = [
+ u" ", u"to be or not to be", u"ignored by length text1",
+ u"ignored by length text2"
+ ]
+ tftext_tokenized = tftext_sp.tokenize(input_text)
+ opt_tokenized = opt_sp.tokenize(input_text)
+ self.assertAllEqual(tftext_tokenized, opt_tokenized)
+
+ def test_tflite_opt_sentence_tokenizer(self):
+ """Check that can convert a Keras model to TFLite and it produces the same result for tokenization."""
+
+ class TokenizerLayer(tf.keras.layers.Layer):
+
+ def __init__(self, sentencepiece_model, **kwargs):
+ super(TokenizerLayer, self).__init__(**kwargs)
+ self.sp = sentencepiece_tokenizer.SentencepieceTokenizer(
+ sentencepiece_model)
+
+ def call(self, input_tensor, **kwargs):
+ return self.sp.tokenize(input_tensor).flat_values
+
+ model = tf.keras.models.Sequential(
+ [TokenizerLayer(self.sentencepiece_model)])
+ input_data = np.array([[
+ u" ", u"to be or not to be", u"ignored by length text1",
+ u"ignored by length text2"
+ ]])
+ tf_result = model.predict(input_data)
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+ converter.target_spec.supported_ops = supported_ops
+ converter.allow_custom_ops = True
+ tflite_model = converter.convert()
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=tflite_model,
+ custom_op_registerers=["TFLite_SentencepieceTokenizerRegisterer"])
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+
+ interpreter.set_tensor(input_details[0]["index"], input_data)
+ interpreter.invoke()
+ output_details = interpreter.get_output_details()
+ expected_result = [
+ 13, 36, 83, 131, 13, 36, 4, 3127, 152, 130, 30, 2424, 168, 1644, 1524,
+ 4, 3127, 152, 130, 30, 2424, 168, 1644, 636
+ ]
+ self.assertAllEqual(tf_result, expected_result)
+ self.assertAllEqual(
+ interpreter.get_tensor(output_details[0]["index"]), expected_result)
+
+ def test_tflite_opt_sentence_detokenizer(self):
+ """Check that can convert a Keras model to TFLite and it produces the same result for tokenization."""
+
+ class DeTokenizerLayer(tf.keras.layers.Layer):
+
+ def __init__(self, sentencepiece_model, **kwargs):
+ super(DeTokenizerLayer, self).__init__(**kwargs)
+ self.sp = sentencepiece_tokenizer.SentencepieceTokenizer(
+ sentencepiece_model)
+
+ def call(self, input_tensor, **kwargs):
+ return self.sp.detokenize(input_tensor)
+
+ model = tf.keras.models.Sequential(
+ [DeTokenizerLayer(self.sentencepiece_model)])
+ input_data = np.array([[
+ 13, 36, 83, 131, 13, 36, 4, 3127, 152, 130, 30, 2424, 168, 1644, 1524,
+ 4, 3127, 152, 130, 30, 2424, 168, 1644, 636
+ ]],
+ dtype=np.int32)
+ tf_result = model.predict(input_data)
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+ converter.target_spec.supported_ops = supported_ops
+ converter.allow_custom_ops = True
+ tflite_model = converter.convert()
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=tflite_model,
+ custom_op_registerers=["TFLite_SentencepieceTokenizerRegisterer"])
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+
+ interpreter.set_tensor(input_details[0]["index"], input_data)
+ interpreter.invoke()
+ output_details = interpreter.get_output_details()
+ expected_result = [
+ "to be or not to be ignored by length text1 ignored by length text2"
+ ]
+ self.assertAllEqual(tf_result, expected_result)
+ self.assertAllEqual(
+ interpreter.get_tensor(output_details[0]["index"]), expected_result)
+
+ def test_tflite_opt_sentence_tokenizer_vocab_size(self):
+    """Check that a Keras model can be converted to TFLite and reports the same vocabulary size."""
+
+ class TokenizerLayer(tf.keras.layers.Layer):
+
+ def __init__(self, sentencepiece_model, **kwargs):
+ super(TokenizerLayer, self).__init__(**kwargs)
+ self.sp = sentencepiece_tokenizer.SentencepieceTokenizer(
+ sentencepiece_model)
+
+ def call(self, input_tensor, **kwargs):
+ return self.sp.vocab_size()
+
+ model = tf.keras.models.Sequential(
+ [TokenizerLayer(self.sentencepiece_model)])
+ input_data = np.array([[""]])
+ tf_result = model.predict(input_data)
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+ converter.target_spec.supported_ops = supported_ops
+ converter.allow_custom_ops = True
+ tflite_model = converter.convert()
+ interpreter = interpreter_wrapper.InterpreterWithCustomOps(
+ model_content=tflite_model,
+ custom_op_registerers=["TFLite_SentencepieceTokenizerRegisterer"])
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+ interpreter.set_tensor(input_details[0]["index"], input_data)
+ interpreter.invoke()
+ output_details = interpreter.get_output_details()
+ expected_result = 4000
+ self.assertEqual(tf_result, expected_result)
+ self.assertAllEqual(
+ interpreter.get_tensor(output_details[0]["index"]), expected_result)
+
+
+class SentencepieceTokenizerBenchmark(tf.test.Benchmark):
+
+ def benchmarkTokenizer(self):
+ sp_model = _GetSentencepieceModel()
+ test_text = [
+ "This week we celebrate the casts and creatives who have come together"
+ " to bring us our favorite.",
+ "More Stacks products demonstrated commitment to excellent support.",
+ "Test, test, test."
+ ]
+
+ tftext_sp = tensorflow_text.SentencepieceTokenizer(sp_model)
+ opt_sp = sentencepiece_tokenizer.SentencepieceTokenizer(sp_model)
+ iter_number = 1000
+ start = time.time()
+ for _ in range(iter_number):
+ _ = opt_sp.tokenize(test_text)
+ self.report_benchmark(
+ iters=iter_number, wall_time=time.time() - start, name="opt")
+ start = time.time()
+ for _ in range(iter_number):
+ _ = tftext_sp.tokenize(test_text)
+ self.report_benchmark(
+ iters=iter_number, wall_time=time.time() - start, name="tf.text")
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow_lite_support/custom_ops/python/tflite_text_api.py b/tensorflow_lite_support/custom_ops/python/tflite_text_api.py
new file mode 100644
index 00000000..1466df29
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/python/tflite_text_api.py
@@ -0,0 +1,126 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Wrapped TF.Text friendly to Tensorflow Lite conversion."""
+
+import tensorflow as tf
+import tensorflow_text as tf_text
+
+
+class WhitespaceTokenizer(tf_text.Tokenizer):
+  """TFLite friendly API for tensorflow_text.WhitespaceTokenizer.tokenize.
+
+ The strings are split on ICU defined whitespace characters. These
+ whitespace characters are dropped. See more details in
+ https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WhitespaceTokenizer.md
+
+ Does not currently support tokenize_with_offsets().
+ """
+
+ def __init__(self):
+ super(WhitespaceTokenizer, self).__init__()
+ self._tokenizer = tf_text.WhitespaceTokenizer()
+
+ def tokenize(self, input_tensor):
+ """Tokenize input strings.
+
+ Args:
+ input_tensor: A `Tensor` of UTF-8 strings with rank 0, 1 or 2.
+
+ Returns:
+ A `RaggedTensor` of tokenized text. The returned shape is the shape of the
+ input tensor with an added ragged dimension for tokens of each string.
+ """
+
+ @tf.function(experimental_implements='name: "tftext:WhitespaceTokenizer"')
+ def func(input_tensor):
+ return self._tokenizer.tokenize(input_tensor)
+
+ return func(input_tensor)
+
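+# Illustrative usage (a documentation sketch, not part of the exported API;
+# the output shown assumes standard tf_text whitespace splitting):
+#
+#   tokenizer = WhitespaceTokenizer()
+#   tokens = tokenizer.tokenize(tf.constant(["hello tflite world"]))
+#   # -> <tf.RaggedTensor [[b'hello', b'tflite', b'world']]>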
+
+def ngrams(data,
+ width,
+ axis=-1,
+ reduction_type=None,
+ string_separator=' ',
+ name=None):
+ """TFLite friendly API for tensorflow_text.ngrams.
+
+  Creates a tensor of n-grams based on `data`, a token tensor. See more details in
+ https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/ngrams.md
+
+ Args:
+ data: The data to reduce. Must be convertible into a tf.Tensor or a
+ tf.RaggedTensor (in which case it will be deconstructed into its component
+ tf.Tensors).
+ width: The width of the ngram window. If there is not sufficient data to
+ fill out the ngram window, the resulting ngram will be empty.
+ axis: The axis to create ngrams along. Note that for string join reductions,
+ only axis '-1' is supported; for other reductions, any positive or
+ negative axis can be used. Should be a constant.
+ reduction_type: A member of the Reduction enum. Should be a constant.
+ Currently supports:
+ * `Reduction.STRING_JOIN`: Join strings in the window. Note that axis must
+ be -1 here.
+ string_separator: The separator string used for `Reduction.STRING_JOIN`.
+ Ignored otherwise. Must be a string constant, not a Tensor.
+ name: The op name.
+
+ Returns:
+ A tensor of ngrams. If `data` is a ragged tensor, this will be a ragged
+ tensor. Otherwise it will be a plain tensor.
+ """
+
+ if reduction_type is not tf_text.Reduction.STRING_JOIN:
+ # TODO(b/162082752): Provide support for Reduction.SUM and Reduction.MEAN
+ raise tf.errors.InvalidArgumentError(
+ None, None, 'only Reduction.STRING_JOIN is currently supported')
+
+ if reduction_type is tf_text.Reduction.STRING_JOIN and axis != -1:
+ raise tf.errors.InvalidArgumentError(
+ None, None, 'For Reduction.STRING_JOIN, axis must be -1')
+
+ experimental_implements = [
+ 'name: "tftext:Ngrams"',
+ 'attr { key: "width" value { i: %d } }' % width,
+ 'attr { key: "axis" value { i: %d } }' % axis,
+ 'attr { key: "reduction_type" value { s: "STRING_JOIN" } }',
+ 'attr { key: "string_separator" value { s: "%s" } }' % string_separator,
+ ]
+ experimental_implements = ' '.join(experimental_implements)
+
+ if isinstance(data, tf.RaggedTensor):
+
+  # Since `data` cannot be converted directly into a Tensor, we define
+ # ragged_func() which takes a deconstructed tf.RaggedTensor
+ # (one flat_values tensor and N row_splits tensors), pass it the
+ # deconstructed version of `data`, and then immediately reconstruct it
+ # within ragged_func().
+ @tf.function(experimental_implements=experimental_implements)
+ def ragged_func(values, *args):
+ ragged_tensor = tf.RaggedTensor.from_nested_row_splits(
+ flat_values=values, nested_row_splits=args)
+ return tf_text.ngrams(ragged_tensor, width, axis, reduction_type,
+ string_separator, name)
+
+ return ragged_func(data.flat_values, *data.nested_row_splits)
+
+ @tf.function(experimental_implements=experimental_implements)
+ def func(data):
+ return tf_text.ngrams(data, width, axis, reduction_type, string_separator,
+ name)
+
+ return func(data)
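+
+# Illustrative usage (a documentation sketch; the output assumes the default
+# string separator ' '):
+#
+#   tokens = tf.ragged.constant([["a", "b", "c"], ["d", "e"]])
+#   bigrams = ngrams(tokens, width=2,
+#                    reduction_type=tf_text.Reduction.STRING_JOIN)
+#   # -> <tf.RaggedTensor [[b'a b', b'b c'], [b'd e']]>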
diff --git a/tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite b/tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite
new file mode 100644
index 00000000..e841b964
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite
Binary files differ
diff --git a/tensorflow_lite_support/custom_ops/tf_configure.sh b/tensorflow_lite_support/custom_ops/tf_configure.sh
new file mode 100644
index 00000000..dbc96da7
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/tf_configure.sh
@@ -0,0 +1,60 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+function write_action_env_to_bazelrc() {
+ echo "build --action_env $1=\"$2\"" >> .bazelrc
+}
+
+function is_linux() {
+ [[ "${PLATFORM}" == "linux" ]]
+}
+
+function is_macos() {
+ [[ "${PLATFORM}" == "darwin" ]]
+}
+
+function is_windows() {
+ # On windows, the shell script is actually running in msys
+ [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]]
+}
+
+# The is_linux/is_macos/is_windows helpers above test ${PLATFORM}, so set it
+# from uname first (otherwise they can never match).
+PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
+
+TF_CFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
+TF_LFLAGS="$(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')"
+HEADER_DIR=${TF_CFLAGS:2}
+if is_windows; then
+  HEADER_DIR=${HEADER_DIR//\\//}
+fi
+if is_windows; then
+ # Use pywrap_tensorflow instead of tensorflow_framework on Windows
+ SHARED_LIBRARY_DIR=${TF_CFLAGS:2:-7}"python"
+else
+ SHARED_LIBRARY_DIR=${TF_LFLAGS:2}
+fi
+SHARED_LIBRARY_NAME=$(echo $TF_LFLAGS | rev | cut -d":" -f1 | rev)
+if ! [[ $TF_LFLAGS =~ .*:.* ]]; then
+ if is_macos; then
+ SHARED_LIBRARY_NAME="libtensorflow_framework.dylib"
+ elif is_windows; then
+ # Use pywrap_tensorflow's import library on Windows. It is in the same dir as the dll/pyd.
+ SHARED_LIBRARY_NAME="_pywrap_tensorflow_internal.lib"
+ else
+ SHARED_LIBRARY_NAME="libtensorflow_framework.so"
+ fi
+fi
+if is_windows; then
+  # Normalize Windows backslashes now that both values are set, so the paths
+  # written to .bazelrc use forward slashes.
+  SHARED_LIBRARY_DIR=${SHARED_LIBRARY_DIR//\\//}
+  SHARED_LIBRARY_NAME=${SHARED_LIBRARY_NAME//\\//}
+fi
+write_action_env_to_bazelrc "TF_HEADER_DIR" ${HEADER_DIR}
+write_action_env_to_bazelrc "TF_SHARED_LIBRARY_DIR" ${SHARED_LIBRARY_DIR}
+write_action_env_to_bazelrc "TF_SHARED_LIBRARY_NAME" ${SHARED_LIBRARY_NAME}
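+
+# After a successful run, .bazelrc gains lines like the following (values are
+# illustrative; actual paths depend on the local TensorFlow installation):
+#   build --action_env TF_HEADER_DIR="/usr/lib/python3/site-packages/tensorflow/include"
+#   build --action_env TF_SHARED_LIBRARY_DIR="/usr/lib/python3/site-packages/tensorflow"
+#   build --action_env TF_SHARED_LIBRARY_NAME="libtensorflow_framework.so.2"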
diff --git a/tensorflow_lite_support/custom_ops/tflite_inference_main.cc b/tensorflow_lite_support/custom_ops/tflite_inference_main.cc
new file mode 100644
index 00000000..2819deea
--- /dev/null
+++ b/tensorflow_lite_support/custom_ops/tflite_inference_main.cc
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This program runs the tflite model specified in --model with random inputs.
+// For string type, the input is filled with a fixed string.
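+//
+// Example invocation (the Bazel target name is assumed from this file's
+// location; the --model flag is defined in main() below):
+// bazel run -c opt \
+//  tensorflow_lite_support/custom_ops:tflite_inference_main -- \
+//  --model=/path/to/model.tflite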
+
+#include <string>
+
+#include <glog/logging.h>
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/model_builder.h"
+#include "tensorflow/lite/string_util.h"
+#include "tensorflow/lite/tools/command_line_flags.h"
+
+void FillRandomString(tflite::DynamicBuffer* buffer,
+ const TfLiteIntArray* dim_array,
+ const std::function<std::string()>& random_func) {
+ int num_elements = 1;
+  for (int i = 0; i < dim_array->size; ++i) {
+ num_elements *= dim_array->data[i];
+ }
+ for (int i = 0; i < num_elements; ++i) {
+ auto str = random_func();
+ buffer->AddString(str.data(), str.length());
+ }
+}
+
+void RunWithRandomInputs(const std::string& filename) {
+  std::unique_ptr<tflite::FlatBufferModel> model =
+      tflite::FlatBufferModel::BuildFromFile(filename.c_str());
+  if (!model) {
+    LOG(FATAL) << "Could not load TFLite model from file: " << filename;
+  }
+
+ // Build the interpreter
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
+ LOG(FATAL) << "Could not initialize interpreter for TFLite model.";
+ }
+
+  // Input tensors could be resized here, if desired. Allocate tensors.
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ LOG(FATAL) << "Could not allocate tensor.";
+ }
+
+  // Fill the input tensors with random data.
+ std::vector<std::vector<uint8_t>> sample;
+ for (int tensor_idx : interpreter->inputs()) {
+ auto tensor = interpreter->tensor(tensor_idx);
+ if (tensor->type == kTfLiteString) {
+ tflite::DynamicBuffer buffer;
+ FillRandomString(&buffer, tensor->dims, []() {
+ return "we're have some friends over saturday to hang out in the "
+ "yard";
+ });
+ buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
+ } else {
+ std::vector<uint8_t> data(tensor->bytes);
+ for (auto it = data.begin(); it != data.end(); ++it) {
+ *it = random();
+ }
+ sample.push_back(data);
+ tensor->data.raw = reinterpret_cast<char*>(sample.rbegin()->data());
+ }
+ }
+
+  // Run inference.
+ if (interpreter->Invoke() != kTfLiteOk) {
+ LOG(FATAL) << "Failed to run the model.";
+ }
+
+ // Get the output.
+ for (int tensor_idx : interpreter->outputs()) {
+ auto tensor = interpreter->tensor(tensor_idx);
+ LOG(INFO) << "Output type: " << TfLiteTypeGetName(tensor->type);
+ }
+}
+
+int main(int argc, char** argv) {
+ // Parse flags to get the filename.
+ std::string filename;
+ std::vector<tflite::Flag> flag_list{tflite::Flag::CreateFlag(
+ "model", &filename, "The tflite model to run sample inference.",
+ tflite::Flag::kRequired)};
+ tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list);
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ // Run the model with random inputs.
+ RunWithRandomInputs(filename);
+ return 0;
+}
diff --git a/tensorflow_lite_support/examples/task/text/desktop/BUILD b/tensorflow_lite_support/examples/task/text/desktop/BUILD
new file mode 100644
index 00000000..067d59cb
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/text/desktop/BUILD
@@ -0,0 +1,68 @@
+package(
+ default_visibility = [
+ "//tensorflow_lite_support:users",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Example usage:
+# bazel run -c opt \
+# tensorflow_lite_support/examples/task/text/desktop:bert_question_answerer_demo \
+# -- \
+# --model_path=/path/to/model.tflite \
+# --question="question to ask" \
+# --context="context for the question to ask"
+cc_binary(
+ name = "bert_question_answerer_demo",
+ srcs = ["bert_question_answerer_demo.cc"],
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+# Example usage:
+# bazel run -c opt \
+# tensorflow_lite_support/examples/task/text/desktop:bert_nl_classifier_demo \
+# -- \
+# --model_path=/path/to/model.tflite \
+# --text="text to classify"
+cc_binary(
+ name = "bert_nl_classifier_demo",
+ srcs = ["bert_nl_classifier_demo.cc"],
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:category",
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+# Example usage:
+# bazel run -c opt \
+# tensorflow_lite_support/examples/task/text/desktop:nl_classifier_demo \
+# -- \
+# --model_path=/path/to/model.tflite \
+# --text="text to classify" \
+# --input_tensor_name="input_text" \
+# --output_score_tensor_name="probability"
+cc_binary(
+ name = "nl_classifier_demo",
+ srcs = ["nl_classifier_demo.cc"],
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:category",
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
diff --git a/tensorflow_lite_support/examples/task/text/desktop/README.md b/tensorflow_lite_support/examples/task/text/desktop/README.md
new file mode 100644
index 00000000..859504fc
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/text/desktop/README.md
@@ -0,0 +1,134 @@
+# CLI Demos for C++ Text Task APIs
+
+This folder contains simple command-line tools for easily trying out the C++
+Text Task APIs.
+
+## Bert Question Answerer
+
+#### Prerequisites
+
+You will need:
+
+* a Bert-based TFLite question answerer model from Model Maker
+(e.g. [mobilebert][1] or [albert][2], available on TensorFlow Hub).
+
+#### Usage
+
+In the console, run:
+
+```bash
+# Download the model:
+curl \
+ -L 'https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1?lite-format=tflite' \
+ -o /tmp/mobilebert.tflite
+
+# Run the classification tool:
+bazel run -c opt \
+ tensorflow_lite_support/examples/task/text/desktop:bert_question_answerer_demo -- \
+ --model_path=/tmp/mobilebert.tflite \
+ --question="Where is Amazon rainforest?" \
+ --context="The Amazon rainforest, alternatively, the Amazon Jungle, also known in \
+English as Amazonia, is a moist broadleaf tropical rainforest in the Amazon \
+biome that covers most of the Amazon basin of South America. This basin \
+encompasses 7,000,000 km2 (2,700,000 sq mi), of which \
+5,500,000 km2 (2,100,000 sq mi) are covered by the rainforest. This region \
+includes territory belonging to nine nations."
+```
+
+#### Results
+
+In the console, you should get:
+
+```
+answer[0]: 'South America.'
+logit: 1.84847, start_index: 39, end_index: 40
+answer[1]: 'most of the Amazon basin of South America.'
+logit: 1.2921, start_index: 34, end_index: 40
+answer[2]: 'the Amazon basin of South America.'
+logit: -0.0959535, start_index: 36, end_index: 40
+answer[3]: 'the Amazon biome that covers most of the Amazon basin of South America.'
+logit: -0.498558, start_index: 28, end_index: 40
+answer[4]: 'Amazon basin of South America.'
+logit: -0.774266, start_index: 37, end_index: 40
+```
+
+## NLClassifier
+
+#### Prerequisites
+
+You will need:
+
+* a TFLite text classification model in the expected format (e.g.
+[movie_review_model][3], a model to classify movie reviews). You'll need to
+configure the input and output tensors for the API; see the [doc][4] for
+details.
+
+#### Usage
+
+In the console, run:
+
+```bash
+# Download the model:
+curl \
+ -L 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite' \
+ -o /tmp/movie_review.tflite
+
+# Run the classification tool:
+bazel run -c opt \
+ tensorflow_lite_support/examples/task/text/desktop:nl_classifier_demo -- \
+ --model_path=/tmp/movie_review.tflite \
+ --text="What a waste of my time." \
+ --input_tensor_name="input_text" \
+ --output_score_tensor_name="probability"
+```
+
+#### Results
+
+In the console, you should get:
+
+```
+category[0]: 'Negative' : '0.81313'
+category[1]: 'Positive' : '0.18687'
+```
+
+## BertNLClassifier
+
+#### Prerequisites
+
+TODO(b/163086702): Update the links to models with metadata attached.
+
+You will need:
+
+* a Bert-based TFLite text classification model from Model Maker (e.g. [movie_review_model][5], available on TensorFlow Hub).
+
+#### Usage
+
+In the console, run:
+
+```bash
+# Download the model:
+curl \
+ -L 'https://url/to/bert/nl/classifier' \
+ -o /tmp/bert_movie_review.tflite
+
+# Run the classification tool:
+bazel run -c opt \
+ tensorflow_lite_support/examples/task/text/desktop:bert_nl_classifier_demo -- \
+ --model_path=/tmp/bert_movie_review.tflite \
+ --text="it's a charming and often affecting journey"
+```
+
+#### Results
+
+In the console, you should get:
+
+```
+category[0]: 'negative' : '0.00006'
+category[1]: 'positive' : '0.99994'
+```
+
+[1]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
+[2]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
+[3]: https://www.tensorflow.org/lite/models/text_classification/overview
+[4]: https://github.com/tensorflow/tflite-support/blob/fe8b69002f5416900285dc69e2baa078c91bd994/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h#L55
+[5]: http://bert/nl/classifier/model
diff --git a/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc b/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
new file mode 100644
index 00000000..15ea3bff
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc
@@ -0,0 +1,77 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
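+
+// Example usage (mirroring the invocation documented in the BUILD file):
+// bazel run -c opt \
+//  tensorflow_lite_support/examples/task/text/desktop:bert_nl_classifier_demo \
+//  -- \
+//  --model_path=/path/to/model.tflite \
+//  --text="text to classify"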
+#include <iostream>
+#include <limits>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/category.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
+
+ABSL_FLAG(std::string, model_path, "",
+ "Absolute path to the '.tflite' bert classification model.");
+ABSL_FLAG(std::string, text, "", "Text to classify.");
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace nlclassifier {
+
+absl::Status Classify() {
+ ASSIGN_OR_RETURN(
+ std::unique_ptr<BertNLClassifier> classifier,
+ BertNLClassifier::CreateFromFile(absl::GetFlag(FLAGS_model_path)));
+
+ std::vector<core::Category> categories =
+ classifier->Classify(absl::GetFlag(FLAGS_text));
+
+ for (int i = 0; i < categories.size(); ++i) {
+ const core::Category& category = categories[i];
+ std::cout << absl::StrFormat("category[%d]: '%s' : '%.5f'\n", i,
+ category.class_name, category.score);
+ }
+
+ return absl::OkStatus();
+}
+
+} // namespace nlclassifier
+} // namespace text
+} // namespace task
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // Parse command line arguments and perform sanity checks.
+ absl::ParseCommandLine(argc, argv);
+ if (absl::GetFlag(FLAGS_model_path).empty()) {
+ std::cerr << "Missing mandatory 'model_path' argument.\n";
+ return 1;
+ }
+ if (absl::GetFlag(FLAGS_text).empty()) {
+    std::cerr << "Missing mandatory 'text' argument.\n";
+ return 1;
+ }
+
+ // Run classification.
+ absl::Status status = tflite::task::text::nlclassifier::Classify();
+ if (status.ok()) {
+ return 0;
+ } else {
+ std::cerr << "Classification failed: " << status.message() << "\n";
+ return 1;
+ }
+}
diff --git a/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc b/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
new file mode 100644
index 00000000..743db71d
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc
@@ -0,0 +1,81 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
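+
+// Example usage (mirroring the invocation documented in the BUILD file):
+// bazel run -c opt \
+//  tensorflow_lite_support/examples/task/text/desktop:bert_question_answerer_demo \
+//  -- \
+//  --model_path=/path/to/model.tflite \
+//  --question="question to ask" \
+//  --context="context for the question to ask"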
+#include <iostream>
+#include <limits>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
+
+ABSL_FLAG(std::string, model_path, "",
+ "Absolute path to the '.tflite' bert question answerer model.");
+ABSL_FLAG(std::string, question, "", "Question to ask.");
+ABSL_FLAG(std::string, context, "",
+ "Context the asked question is based upon.");
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace qa {
+
+absl::Status Answer() {
+ ASSIGN_OR_RETURN(
+ std::unique_ptr<QuestionAnswerer> answerer,
+ BertQuestionAnswerer::CreateFromFile(absl::GetFlag(FLAGS_model_path)));
+
+ std::vector<QaAnswer> answers = answerer->Answer(
+ absl::GetFlag(FLAGS_context), absl::GetFlag(FLAGS_question));
+ for (int i = 0; i < answers.size(); ++i) {
+ const QaAnswer& answer = answers[i];
+ std::cout << absl::StrFormat(
+        "answer[%d]: '%s'\n logit: %.5f, start_index: %d, end_index: %d\n",
+ i, answer.text, answer.pos.logit, answer.pos.start, answer.pos.end);
+ }
+
+ return absl::OkStatus();
+}
+
+} // namespace qa
+} // namespace text
+} // namespace task
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // Parse command line arguments and perform sanity checks.
+ absl::ParseCommandLine(argc, argv);
+ if (absl::GetFlag(FLAGS_model_path).empty()) {
+ std::cerr << "Missing mandatory 'model_path' argument.\n";
+ return 1;
+ }
+ if (absl::GetFlag(FLAGS_question).empty()) {
+ std::cerr << "Missing mandatory 'question' argument.\n";
+ return 1;
+ }
+ if (absl::GetFlag(FLAGS_context).empty()) {
+ std::cerr << "Missing mandatory 'context' argument.\n";
+ return 1;
+ }
+ // Run the answerer.
+ absl::Status status = tflite::task::text::qa::Answer();
+ if (status.ok()) {
+ return 0;
+ } else {
+ std::cerr << "Answer failed: " << status.message() << "\n";
+ return 1;
+ }
+}
diff --git a/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc b/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
new file mode 100644
index 00000000..2e96ec63
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc
@@ -0,0 +1,112 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
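+
+// Example usage (mirroring the invocation documented in the BUILD file):
+// bazel run -c opt \
+//  tensorflow_lite_support/examples/task/text/desktop:nl_classifier_demo \
+//  -- \
+//  --model_path=/path/to/model.tflite \
+//  --text="text to classify" \
+//  --input_tensor_name="input_text" \
+//  --output_score_tensor_name="probability"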
+#include <iostream>
+#include <limits>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/category.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
+
+ABSL_FLAG(std::string, model_path, "",
+ "Absolute path to the '.tflite' classification model.");
+ABSL_FLAG(std::string, text, "", "Text to classify.");
+ABSL_FLAG(int, input_tensor_index, -1, "Input tensor index of the model.");
+ABSL_FLAG(int, output_score_tensor_index, -1,
+ "Output score tensor index of the model.");
+ABSL_FLAG(int, output_label_tensor_index, -1,
+ "Output label tensor index of the model.");
+ABSL_FLAG(std::string, input_tensor_name, "",
+ "Input tensor name of the model.");
+ABSL_FLAG(std::string, output_score_tensor_name, "",
+ "Output score tensor name of the model.");
+ABSL_FLAG(std::string, output_label_tensor_name, "",
+ "Output label tensor name of the model.");
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace nlclassifier {
+
+absl::Status Classify() {
+ NLClassifierOptions options{};
+ if (absl::GetFlag(FLAGS_input_tensor_index) >= 0) {
+ options.input_tensor_index = absl::GetFlag(FLAGS_input_tensor_index);
+ }
+ if (absl::GetFlag(FLAGS_output_score_tensor_index) >= 0) {
+ options.output_score_tensor_index =
+ absl::GetFlag(FLAGS_output_score_tensor_index);
+ }
+ if (absl::GetFlag(FLAGS_output_label_tensor_index) >= 0) {
+ options.output_label_tensor_index =
+ absl::GetFlag(FLAGS_output_label_tensor_index);
+ }
+ if (!absl::GetFlag(FLAGS_input_tensor_name).empty()) {
+ options.input_tensor_name = absl::GetFlag(FLAGS_input_tensor_name);
+ }
+ if (!absl::GetFlag(FLAGS_output_score_tensor_name).empty()) {
+ options.output_score_tensor_name =
+ absl::GetFlag(FLAGS_output_score_tensor_name);
+ }
+ if (!absl::GetFlag(FLAGS_output_label_tensor_name).empty()) {
+ options.output_label_tensor_name =
+ absl::GetFlag(FLAGS_output_label_tensor_name);
+ }
+
+ ASSIGN_OR_RETURN(std::unique_ptr<NLClassifier> classifier,
+ NLClassifier::CreateFromFileAndOptions(
+ absl::GetFlag(FLAGS_model_path), options));
+
+ std::vector<core::Category> categories =
+ classifier->Classify(absl::GetFlag(FLAGS_text));
+
+ for (int i = 0; i < categories.size(); ++i) {
+ const core::Category& category = categories[i];
+ std::cout << absl::StrFormat("category[%d]: '%s' : '%.5f'\n", i,
+ category.class_name, category.score);
+ }
+
+ return absl::OkStatus();
+}
+
+} // namespace nlclassifier
+} // namespace text
+} // namespace task
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // Parse command line arguments and perform sanity checks.
+ absl::ParseCommandLine(argc, argv);
+ if (absl::GetFlag(FLAGS_model_path).empty()) {
+ std::cerr << "Missing mandatory 'model_path' argument.\n";
+ return 1;
+ }
+ if (absl::GetFlag(FLAGS_text).empty()) {
+    std::cerr << "Missing mandatory 'text' argument.\n";
+ return 1;
+ }
+
+ // Run classification.
+ absl::Status status = tflite::task::text::nlclassifier::Classify();
+ if (status.ok()) {
+ return 0;
+ } else {
+ std::cerr << "Classification failed: " << status.message() << "\n";
+ return 1;
+ }
+}
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/BUILD b/tensorflow_lite_support/examples/task/vision/desktop/BUILD
new file mode 100644
index 00000000..f61984ee
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/BUILD
@@ -0,0 +1,68 @@
+package(
+ default_visibility = [
+ "//tensorflow_lite_support:users",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_binary(
+ name = "image_classifier_demo",
+ srcs = ["image_classifier_demo.cc"],
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:external_file_handler",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision:image_classifier",
+ "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils",
+ "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+cc_binary(
+ name = "object_detector_demo",
+ srcs = ["object_detector_demo.cc"],
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:external_file_handler",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision:object_detector",
+ "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils",
+ "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+cc_binary(
+ name = "image_segmenter_demo",
+ srcs = ["image_segmenter_demo.cc"],
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/core:external_file_handler",
+ "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision:image_segmenter",
+ "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils",
+ "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/README.md b/tensorflow_lite_support/examples/task/vision/desktop/README.md
new file mode 100644
index 00000000..73c6b637
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/README.md
@@ -0,0 +1,180 @@
+# CLI Demos for C++ Vision Task APIs
+
+This folder contains simple command-line tools for easily trying out the C++
+Vision Task APIs.
+
+## Image Classifier
+
+#### Prerequisites
+
+You will need:
+
+* a TFLite image classification model (e.g. [aiy/vision/classifier/birds_V1][1],
+a bird classification model available on TensorFlow Hub),
+* a PNG, JPEG or GIF image to run classification on, e.g.:
+
+![sparrow](g3doc/sparrow.jpg)
+
+#### Usage
+
+In the console, run:
+
+```bash
+# Download the model:
+curl \
+ -L 'https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3?lite-format=tflite' \
+ -o /tmp/aiy_vision_classifier_birds_V1_3.tflite
+
+# Run the classification tool:
+bazel run -c opt \
+ tensorflow_lite_support/examples/task/vision/desktop:image_classifier_demo -- \
+ --model_path=/tmp/aiy_vision_classifier_birds_V1_3.tflite \
+ --image_path=\
+$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg \
+ --max_results=3
+```
+
+#### Results
+
+In the console, you should get:
+
+```
+Results:
+ Rank #0:
+ index : 671
+ score : 0.91406
+ class name : /m/01bwb9
+ display name: Passer domesticus
+ Rank #1:
+ index : 670
+ score : 0.00391
+ class name : /m/01bwbt
+ display name: Passer montanus
+ Rank #2:
+ index : 495
+ score : 0.00391
+ class name : /m/0bwm6m
+ display name: Passer italiae
+```
+
+## Object Detector
+
+#### Prerequisites
+
+You will need:
+
+* a TFLite object detection model (e.g. [ssd_mobilenet_v1][2], a generic object
+detection model available on TensorFlow Hub),
+* a PNG, JPEG or GIF image to run detection on, e.g.:
+
+![dogs](g3doc/dogs.jpg)
+
+#### Usage
+
+In the console, run:
+
+```bash
+# Download the model:
+curl \
+ -L 'https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite' \
+ -o /tmp/ssd_mobilenet_v1_1_metadata_1.tflite
+
+# Run the detection tool:
+bazel run -c opt \
+ tensorflow_lite_support/examples/task/vision/desktop:object_detector_demo -- \
+ --model_path=/tmp/ssd_mobilenet_v1_1_metadata_1.tflite \
+ --image_path=\
+$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg \
+ --output_png=/tmp/detection-output.png \
+ --max_results=2
+```
+
+#### Results
+
+In the console, you should get:
+
+```
+Results saved to: /tmp/detection-output.png
+Results:
+ Detection #0 (red):
+ Box: (x: 355, y: 133, w: 190, h: 206)
+ Top-1 class:
+ index : 17
+ score : 0.73828
+ class name : dog
+ Detection #1 (green):
+ Box: (x: 103, y: 15, w: 138, h: 369)
+ Top-1 class:
+ index : 17
+ score : 0.73047
+ class name : dog
+```
+
+And `/tmp/detection-output.png` should contain:
+
+![detection-output](g3doc/detection-output.png)
+
+## Image Segmenter
+
+#### Prerequisites
+
+You will need:
+
+* a TFLite image segmentation model (e.g. [deeplab_v3][3], a generic
+segmentation model available on TensorFlow Hub),
+* a PNG, JPEG or GIF image to run segmentation on, e.g.:
+
+![plane](g3doc/plane.jpg)
+
+#### Usage
+
+In the console, run:
+
+```bash
+# Download the model:
+curl \
+ -L 'https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1?lite-format=tflite' \
+ -o /tmp/deeplabv3_1_metadata_1.tflite
+
+# Run the segmentation tool:
+bazel run -c opt \
+ tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo -- \
+ --model_path=/tmp/deeplabv3_1_metadata_1.tflite \
+ --image_path=\
+$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg \
+ --output_mask_png=/tmp/segmentation-output.png
+```
+
+#### Results
+
+In the console, you should get:
+
+```
+Category mask saved to: /tmp/segmentation-output.png
+Color Legend:
+ (r: 000, g: 000, b: 000):
+ index : 0
+ class name : background
+ (r: 128, g: 000, b: 000):
+ index : 1
+ class name : aeroplane
+
+# (omitting multiple lines for conciseness) ...
+
+ (r: 128, g: 192, b: 000):
+ index : 19
+ class name : train
+ (r: 000, g: 064, b: 128):
+ index : 20
+ class name : tv
+Tip: use a color picker on the output PNG file to inspect the output mask with
+this legend.
+```
+
+And `/tmp/segmentation-output.png` should contain the segmentation mask:
+
+![segmentation-output](g3doc/segmentation-output.png)
+
+[1]: https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3
+[2]: https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2
+[3]: https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/detection-output.png b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/detection-output.png
new file mode 100644
index 00000000..c8d56f40
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/detection-output.png
Binary files differ
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg
new file mode 100644
index 00000000..9db4bee7
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg
Binary files differ
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg
new file mode 100644
index 00000000..0edefa40
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg
Binary files differ
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png
new file mode 100644
index 00000000..e871df33
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png
Binary files differ
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg
new file mode 100644
index 00000000..25d213ea
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg
Binary files differ
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
new file mode 100644
index 00000000..cd97f4a2
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
@@ -0,0 +1,173 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Example usage:
+// bazel run -c opt \
+// tensorflow_lite_support/examples/task/vision/desktop:image_classifier_demo \
+// -- \
+// --model_path=/path/to/model.tflite \
+// --image_path=/path/to/image.jpg
+
+#include <iostream>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
+
+ABSL_FLAG(std::string, model_path, "",
+ "Absolute path to the '.tflite' image classifier model.");
+ABSL_FLAG(std::string, image_path, "",
+ "Absolute path to the image to classify. The image must be RGB or "
+ "RGBA (grayscale is not supported). The image EXIF orientation "
+ "flag, if any, is NOT taken into account.");
+ABSL_FLAG(int32, max_results, 5,
+ "Maximum number of classification results to display.");
+ABSL_FLAG(float, score_threshold, 0,
+ "Classification results with a confidence score below this value are "
+ "rejected. If >= 0, overrides the score threshold(s) provided in the "
+ "TFLite Model Metadata. Ignored otherwise.");
+ABSL_FLAG(
+ std::vector<std::string>, class_name_whitelist, {},
+ "Comma-separated list of class names that acts as a whitelist. If "
+ "non-empty, classification results whose 'class_name' is not in this list "
+ "are filtered out. Mutually exclusive with 'class_name_blacklist'.");
+ABSL_FLAG(
+ std::vector<std::string>, class_name_blacklist, {},
+ "Comma-separated list of class names that acts as a blacklist. If "
+ "non-empty, classification results whose 'class_name' is in this list "
+ "are filtered out. Mutually exclusive with 'class_name_whitelist'.");
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+ImageClassifierOptions BuildOptions() {
+ ImageClassifierOptions options;
+ options.mutable_model_file_with_metadata()->set_file_name(
+ absl::GetFlag(FLAGS_model_path));
+ options.set_max_results(absl::GetFlag(FLAGS_max_results));
+ if (absl::GetFlag(FLAGS_score_threshold) >= 0) {
+ options.set_score_threshold(absl::GetFlag(FLAGS_score_threshold));
+ }
+ for (const std::string& class_name :
+ absl::GetFlag(FLAGS_class_name_whitelist)) {
+ options.add_class_name_whitelist(class_name);
+ }
+ for (const std::string& class_name :
+ absl::GetFlag(FLAGS_class_name_blacklist)) {
+ options.add_class_name_blacklist(class_name);
+ }
+ return options;
+}
+
+void DisplayResult(const ClassificationResult& result) {
+ std::cout << "Results:\n";
+ for (int head = 0; head < result.classifications_size(); ++head) {
+ if (result.classifications_size() > 1) {
+ std::cout << absl::StrFormat(" Head index %d:\n", head);
+ }
+ const Classifications& classifications = result.classifications(head);
+ for (int rank = 0; rank < classifications.classes_size(); ++rank) {
+ const Class& classification = classifications.classes(rank);
+ std::cout << absl::StrFormat(" Rank #%d:\n", rank);
+ std::cout << absl::StrFormat(" index : %d\n",
+ classification.index());
+ std::cout << absl::StrFormat(" score : %.5f\n",
+ classification.score());
+ if (classification.has_class_name()) {
+ std::cout << absl::StrFormat(" class name : %s\n",
+ classification.class_name());
+ }
+ if (classification.has_display_name()) {
+ std::cout << absl::StrFormat(" display name: %s\n",
+ classification.display_name());
+ }
+ }
+ }
+}
+
+absl::Status Classify() {
+ // Build ImageClassifier.
+ const ImageClassifierOptions& options = BuildOptions();
+ ASSIGN_OR_RETURN(std::unique_ptr<ImageClassifier> image_classifier,
+ ImageClassifier::CreateFromOptions(options));
+
+ // Load image in a FrameBuffer.
+ ASSIGN_OR_RETURN(ImageData image,
+ DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
+ std::unique_ptr<FrameBuffer> frame_buffer;
+ if (image.channels == 3) {
+ frame_buffer =
+ CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
+ } else if (image.channels == 4) {
+ frame_buffer =
+ CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
+ } else {
+ return absl::InvalidArgumentError(absl::StrFormat(
+ "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
+ image.channels));
+ }
+
+ // Run classification and display results.
+ ASSIGN_OR_RETURN(ClassificationResult result,
+ image_classifier->Classify(*frame_buffer));
+ DisplayResult(result);
+
+ // Cleanup and return.
+ ImageDataFree(&image);
+ return absl::OkStatus();
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // Parse command line arguments and perform sanity checks.
+ absl::ParseCommandLine(argc, argv);
+ if (absl::GetFlag(FLAGS_model_path).empty()) {
+ std::cerr << "Missing mandatory 'model_path' argument.\n";
+ return 1;
+ }
+ if (absl::GetFlag(FLAGS_image_path).empty()) {
+ std::cerr << "Missing mandatory 'image_path' argument.\n";
+ return 1;
+ }
+ if (!absl::GetFlag(FLAGS_class_name_whitelist).empty() &&
+ !absl::GetFlag(FLAGS_class_name_blacklist).empty()) {
+ std::cerr << "'class_name_whitelist' and 'class_name_blacklist' arguments "
+ "are mutually exclusive.\n";
+ return 1;
+ }
+
+ // Run classification.
+ absl::Status status = tflite::task::vision::Classify();
+ if (status.ok()) {
+ return 0;
+ } else {
+ std::cerr << "Classification failed: " << status.message() << "\n";
+ return 1;
+ }
+}
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
new file mode 100644
index 00000000..02af824f
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
@@ -0,0 +1,201 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Example usage:
+// bazel run -c opt \
+// tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo \
+// -- \
+// --model_path=/path/to/model.tflite \
+// --image_path=/path/to/image.jpg \
+// --output_mask_png=/path/to/output/mask.png
+
+#include <iostream>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/status/status.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
+
+ABSL_FLAG(std::string, model_path, "",
+ "Absolute path to the '.tflite' image segmenter model.");
+ABSL_FLAG(std::string, image_path, "",
+ "Absolute path to the image to segment. The image must be RGB or "
+ "RGBA (grayscale is not supported). The image EXIF orientation "
+ "flag, if any, is NOT taken into account.");
+ABSL_FLAG(std::string, output_mask_png, "",
+ "Absolute path to the output category mask (confidence masks outputs "
+ "are not supported by this tool). Must have a '.png' extension.");
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+ImageSegmenterOptions BuildOptions() {
+ ImageSegmenterOptions options;
+ options.mutable_model_file_with_metadata()->set_file_name(
+ absl::GetFlag(FLAGS_model_path));
+ // Confidence masks are not supported by this tool: output_type is set to
+ // CATEGORY_MASK by default.
+ return options;
+}
+
+absl::Status EncodeMaskToPngFile(const SegmentationResult& result) {
+ if (result.segmentation_size() != 1) {
+ return absl::UnimplementedError(
+ "Image segmentation models with multiple output segmentations are not "
+ "supported by this tool.");
+ }
+ const Segmentation& segmentation = result.segmentation(0);
+ // Extract raw mask data as a uint8 pointer.
+ const uint8* raw_mask =
+ reinterpret_cast<const uint8*>(segmentation.category_mask().data());
+
+ // Create RgbImageData for the output mask.
+ uint8* pixel_data = static_cast<uint8*>(
+ malloc(segmentation.width() * segmentation.height() * 3 * sizeof(uint8)));
+ ImageData mask = {.pixel_data = pixel_data,
+ .width = segmentation.width(),
+ .height = segmentation.height(),
+ .channels = 3};
+
+ // Populate RgbImageData from the raw mask and ColoredLabel-s.
+ for (int i = 0; i < segmentation.width() * segmentation.height(); ++i) {
+ Segmentation::ColoredLabel colored_label =
+ segmentation.colored_labels(raw_mask[i]);
+ pixel_data[3 * i] = colored_label.r();
+ pixel_data[3 * i + 1] = colored_label.g();
+ pixel_data[3 * i + 2] = colored_label.b();
+ }
+
+ // Encode mask as PNG.
+ RETURN_IF_ERROR(
+ EncodeImageToPngFile(mask, absl::GetFlag(FLAGS_output_mask_png)));
+ std::cout << absl::StrFormat("Category mask saved to: %s\n",
+ absl::GetFlag(FLAGS_output_mask_png));
+
+ // Cleanup and return.
+ ImageDataFree(&mask);
+ return absl::OkStatus();
+}
+
+absl::Status DisplayColorLegend(const SegmentationResult& result) {
+ if (result.segmentation_size() != 1) {
+ return absl::UnimplementedError(
+ "Image segmentation models with multiple output segmentations are not "
+ "supported by this tool.");
+ }
+ const Segmentation& segmentation = result.segmentation(0);
+ const int num_labels = segmentation.colored_labels_size();
+
+ std::cout << "Color Legend:\n";
+ for (int index = 0; index < num_labels; ++index) {
+ Segmentation::ColoredLabel colored_label =
+ segmentation.colored_labels(index);
+ std::cout << absl::StrFormat(" (r: %03d, g: %03d, b: %03d):\n",
+ colored_label.r(), colored_label.g(),
+ colored_label.b());
+ std::cout << absl::StrFormat(" index : %d\n", index);
+ if (colored_label.has_class_name()) {
+ std::cout << absl::StrFormat(" class name : %s\n",
+ colored_label.class_name());
+ }
+ if (colored_label.has_display_name()) {
+ std::cout << absl::StrFormat(" display name: %s\n",
+ colored_label.display_name());
+ }
+ }
+ std::cout << "Tip: use a color picker on the output PNG file to inspect the "
+ "output mask with this legend.\n";
+
+ return absl::OkStatus();
+}
+
+absl::Status Segment() {
+  // Build ImageSegmenter.
+ const ImageSegmenterOptions& options = BuildOptions();
+ ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> image_segmenter,
+ ImageSegmenter::CreateFromOptions(options));
+
+ // Load image in a FrameBuffer.
+ ASSIGN_OR_RETURN(ImageData image,
+ DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
+ std::unique_ptr<FrameBuffer> frame_buffer;
+ if (image.channels == 3) {
+ frame_buffer =
+ CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
+ } else if (image.channels == 4) {
+ frame_buffer =
+ CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
+ } else {
+ return absl::InvalidArgumentError(absl::StrFormat(
+ "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
+ image.channels));
+ }
+
+ // Run segmentation and save category mask.
+ ASSIGN_OR_RETURN(SegmentationResult result,
+ image_segmenter->Segment(*frame_buffer));
+ RETURN_IF_ERROR(EncodeMaskToPngFile(result));
+
+ // Display the legend.
+ RETURN_IF_ERROR(DisplayColorLegend(result));
+
+ // Cleanup and return.
+ ImageDataFree(&image);
+ return absl::OkStatus();
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // Parse command line arguments and perform sanity checks.
+ absl::ParseCommandLine(argc, argv);
+ if (absl::GetFlag(FLAGS_model_path).empty()) {
+ std::cerr << "Missing mandatory 'model_path' argument.\n";
+ return 1;
+ }
+ if (absl::GetFlag(FLAGS_image_path).empty()) {
+ std::cerr << "Missing mandatory 'image_path' argument.\n";
+ return 1;
+ }
+ if (absl::GetFlag(FLAGS_output_mask_png).empty()) {
+ std::cerr << "Missing mandatory 'output_mask_png' argument.\n";
+ return 1;
+ }
+ if (!absl::EndsWithIgnoreCase(absl::GetFlag(FLAGS_output_mask_png), ".png")) {
+ std::cerr << "Argument 'output_mask_png' must end with '.png' or '.PNG'\n";
+ return 1;
+ }
+
+ // Run segmentation.
+ absl::Status status = tflite::task::vision::Segment();
+ if (status.ok()) {
+ return 0;
+ } else {
+ std::cerr << "Segmentation failed: " << status.message() << "\n";
+ return 1;
+ }
+}
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
new file mode 100644
index 00000000..b7ab651a
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
@@ -0,0 +1,251 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Example usage:
+// bazel run -c opt \
+// tensorflow_lite_support/examples/task/vision/desktop:object_detector_demo \
+// -- \
+// --model_path=/path/to/model.tflite \
+// --image_path=/path/to/image.jpg \
+// --output_png=/path/to/output.png
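+//
+// Additional optional flags (illustrative values):
+//   --max_results=3 --score_threshold=0.5 --class_name_whitelist="car,person"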
+
+#include <iostream>
+#include <limits>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "absl/status/status.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/object_detector.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
+
+ABSL_FLAG(std::string, model_path, "",
+ "Absolute path to the '.tflite' object detector model.");
+ABSL_FLAG(std::string, image_path, "",
+ "Absolute path to the image to run detection on. The image must be "
+ "RGB or RGBA (grayscale is not supported). The image EXIF "
+ "orientation flag, if any, is NOT taken into account.");
+ABSL_FLAG(std::string, output_png, "",
+ "Absolute path to a file where to draw the detection results on top "
+ "of the input image. Must have a '.png' extension.");
+ABSL_FLAG(int32, max_results, 5,
+ "Maximum number of detection results to display.");
+ABSL_FLAG(
+ float, score_threshold, std::numeric_limits<float>::lowest(),
+ "Detection results with a confidence score below this value are "
+ "rejected. If specified, overrides the score threshold(s) provided in the "
+ "TFLite Model Metadata. Ignored otherwise.");
+ABSL_FLAG(
+ std::vector<std::string>, class_name_whitelist, {},
+ "Comma-separated list of class names that acts as a whitelist. If "
+ "non-empty, detections results whose 'class_name' is not in this list "
+ "are filtered out. Mutually exclusive with 'class_name_blacklist'.");
+ABSL_FLAG(std::vector<std::string>, class_name_blacklist, {},
+ "Comma-separated list of class names that acts as a blacklist. If "
+ "non-empty, detections results whose 'class_name' is in this list "
+ "are filtered out. Mutually exclusive with 'class_name_whitelist'.");
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+// The line thickness (in pixels) for drawing the detection results.
+constexpr int kLineThickness = 3;
+
+// The number of colors used for drawing the detection results.
+constexpr int kColorMapSize = 10;
+
+// The names of the colors used for drawing the detection results.
+constexpr std::array<absl::string_view, 10> kColorMapNames = {
+ "red", "green", "blue", "yellow", "fuschia",
+ "dark red", "dark green", "dark blue", "gray", "black"};
+
+// The colors used for drawing the detection results as a flattened array of
+// {R,G,B} components.
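+// For example, the first triplet {255, 0, 0} corresponds to kColorMapNames[0]
+// ("red"); detection #i is drawn with color (i % kColorMapSize).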
+constexpr uint8 kColorMapComponents[30] = {
+ 255, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 0, 255, 0, 255,
+ 128, 0, 0, 0, 128, 0, 0, 0, 128, 128, 128, 128, 0, 0, 0};
+} // namespace
+
+ObjectDetectorOptions BuildOptions() {
+ ObjectDetectorOptions options;
+ options.mutable_model_file_with_metadata()->set_file_name(
+ absl::GetFlag(FLAGS_model_path));
+ options.set_max_results(absl::GetFlag(FLAGS_max_results));
+ if (absl::GetFlag(FLAGS_score_threshold) >
+ std::numeric_limits<float>::lowest()) {
+ options.set_score_threshold(absl::GetFlag(FLAGS_score_threshold));
+ }
+ for (const std::string& class_name :
+ absl::GetFlag(FLAGS_class_name_whitelist)) {
+ options.add_class_name_whitelist(class_name);
+ }
+ for (const std::string& class_name :
+ absl::GetFlag(FLAGS_class_name_blacklist)) {
+ options.add_class_name_blacklist(class_name);
+ }
+ return options;
+}
+
+absl::Status EncodeResultToPngFile(const DetectionResult& result,
+ const ImageData* image) {
+ for (int index = 0; index < result.detections_size(); ++index) {
+ // Get bounding box as left, top, right, bottom.
+ const BoundingBox& box = result.detections(index).bounding_box();
+ const int left = box.origin_x();
+ const int top = box.origin_y();
+ const int right = box.origin_x() + box.width();
+ const int bottom = box.origin_y() + box.height();
+ // Get color components.
+ const uint8 r = kColorMapComponents[3 * (index % kColorMapSize)];
+ const uint8 g = kColorMapComponents[3 * (index % kColorMapSize) + 1];
+ const uint8 b = kColorMapComponents[3 * (index % kColorMapSize) + 2];
+    // Draw. Boxes might have coordinates outside of [0, w) x [0, h), so
+    // clamping is applied.
+ for (int y = std::max(0, top); y < std::min(image->height, bottom); ++y) {
+ for (int x = std::max(0, left); x < std::min(image->width, right); ++x) {
+ int pixel_index = image->channels * (image->width * y + x);
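+        // Only color pixels within kLineThickness of the box border, leaving
+        // the interior of the detection untouched.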
+ if (x < left + kLineThickness || x > right - kLineThickness ||
+ y < top + kLineThickness || y > bottom - kLineThickness) {
+ image->pixel_data[pixel_index] = r;
+ image->pixel_data[pixel_index + 1] = g;
+ image->pixel_data[pixel_index + 2] = b;
+ }
+ }
+ }
+ }
+ // Encode to PNG and return.
+ RETURN_IF_ERROR(
+ EncodeImageToPngFile(*image, absl::GetFlag(FLAGS_output_png)));
+ std::cout << absl::StrFormat("Results saved to: %s\n",
+ absl::GetFlag(FLAGS_output_png));
+ return absl::OkStatus();
+}
+
+void DisplayResult(const DetectionResult& result) {
+ std::cout << "Results:\n";
+ for (int index = 0; index < result.detections_size(); ++index) {
+ std::cout << absl::StrFormat(" Detection #%d (%s):\n", index,
+ kColorMapNames[index % kColorMapSize]);
+ const Detection& detection = result.detections(index);
+ const BoundingBox& box = detection.bounding_box();
+ std::cout << absl::StrFormat(" Box: (x: %d, y: %d, w: %d, h: %d)\n",
+ box.origin_x(), box.origin_y(), box.width(),
+ box.height());
+ if (detection.classes_size() == 0) {
+ std::cout << " No top-1 class available";
+ } else {
+ std::cout << " Top-1 class:\n";
+ const Class& classification = detection.classes(0);
+ std::cout << absl::StrFormat(" index : %d\n",
+ classification.index());
+ std::cout << absl::StrFormat(" score : %.5f\n",
+ classification.score());
+ if (classification.has_class_name()) {
+ std::cout << absl::StrFormat(" class name : %s\n",
+ classification.class_name());
+ }
+ if (classification.has_display_name()) {
+ std::cout << absl::StrFormat(" display name: %s\n",
+ classification.display_name());
+ }
+ }
+ }
+}
+
+absl::Status Detect() {
+ // Build ObjectDetector.
+ const ObjectDetectorOptions& options = BuildOptions();
+ ASSIGN_OR_RETURN(std::unique_ptr<ObjectDetector> object_detector,
+ ObjectDetector::CreateFromOptions(options));
+
+ // Load image in a FrameBuffer.
+ ASSIGN_OR_RETURN(ImageData image,
+ DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
+ std::unique_ptr<FrameBuffer> frame_buffer;
+ if (image.channels == 3) {
+ frame_buffer =
+ CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
+ } else if (image.channels == 4) {
+ frame_buffer =
+ CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
+ } else {
+ return absl::InvalidArgumentError(absl::StrFormat(
+ "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
+ image.channels));
+ }
+
+ // Run object detection and draw results on input image.
+ ASSIGN_OR_RETURN(DetectionResult result,
+ object_detector->Detect(*frame_buffer));
+ RETURN_IF_ERROR(EncodeResultToPngFile(result, &image));
+
+ // Display results as text.
+ DisplayResult(result);
+
+ // Cleanup and return.
+ ImageDataFree(&image);
+ return absl::OkStatus();
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // Parse command line arguments and perform sanity checks.
+ absl::ParseCommandLine(argc, argv);
+ if (absl::GetFlag(FLAGS_model_path).empty()) {
+ std::cerr << "Missing mandatory 'model_path' argument.\n";
+ return 1;
+ }
+ if (absl::GetFlag(FLAGS_image_path).empty()) {
+ std::cerr << "Missing mandatory 'image_path' argument.\n";
+ return 1;
+ }
+ if (absl::GetFlag(FLAGS_output_png).empty()) {
+ std::cerr << "Missing mandatory 'output_png' argument.\n";
+ return 1;
+ }
+ if (!absl::EndsWithIgnoreCase(absl::GetFlag(FLAGS_output_png), ".png")) {
+ std::cerr << "Argument 'output_png' must end with '.png' or '.PNG'\n";
+ return 1;
+ }
+ if (!absl::GetFlag(FLAGS_class_name_whitelist).empty() &&
+ !absl::GetFlag(FLAGS_class_name_blacklist).empty()) {
+ std::cerr << "'class_name_whitelist' and 'class_name_blacklist' arguments "
+ "are mutually exclusive.\n";
+ return 1;
+ }
+
+ // Run detection.
+ absl::Status status = tflite::task::vision::Detect();
+ if (status.ok()) {
+ return 0;
+ } else {
+ std::cerr << "Detection failed: " << status.message() << "\n";
+ return 1;
+ }
+}
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD b/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD
new file mode 100644
index 00000000..9a837e2d
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD
@@ -0,0 +1,22 @@
+package(
+ default_visibility = [
+ "//tensorflow_lite_support:users",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "image_utils",
+ srcs = ["image_utils.cc"],
+ hdrs = ["image_utils.h"],
+ deps = [
+ "//tensorflow_lite_support/cc/port:integral_types",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@stblib//:stb_image",
+ "@stblib//:stb_image_write",
+ ],
+)
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc
new file mode 100644
index 00000000..7c4604e9
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc
@@ -0,0 +1,94 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
+
+#include <cstdlib>
+#include <cstring>
+#include <vector>
+
+// These need to be defined for stb_image.h and stb_image_write.h to include
+// the actual implementations of image decoding/encoding functions.
+#define STB_IMAGE_IMPLEMENTATION
+#define STB_IMAGE_WRITE_IMPLEMENTATION
+
+#include "absl/status/status.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_format.h"
+#include "stb_image.h"
+#include "stb_image_write.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+using ::tflite::support::StatusOr;
+
+StatusOr<ImageData> DecodeImageFromFile(const std::string& file_name) {
+ ImageData image_data;
+ image_data.pixel_data = stbi_load(file_name.c_str(), &image_data.width,
+ &image_data.height, &image_data.channels,
+ /*desired_channels=*/0);
+ if (image_data.pixel_data == nullptr) {
+ return absl::InternalError(absl::StrFormat(
+ "An error occurred while decoding image: %s", stbi_failure_reason()));
+ }
+ if (image_data.channels != 1 && image_data.channels != 3 &&
+ image_data.channels != 4) {
+ stbi_image_free(image_data.pixel_data);
+ return absl::UnimplementedError(
+ absl::StrFormat("Expected image with 1 (grayscale), 3 (RGB) or 4 "
+ "(RGBA) channels, found %d",
+ image_data.channels));
+ }
+ return image_data;
+}
+
+absl::Status EncodeImageToPngFile(const ImageData& image_data,
+ const std::string& image_path) {
+ // Sanity check inputs.
+ if (image_data.width <= 0 || image_data.height <= 0) {
+ return absl::InvalidArgumentError(
+ absl::StrFormat("Expected positive image dimensions, found %d x %d.",
+ image_data.width, image_data.height));
+ }
+ if (image_data.channels != 1 && image_data.channels != 3 &&
+ image_data.channels != 4) {
+ return absl::UnimplementedError(
+ absl::StrFormat("Expected image data with 1 (grayscale), 3 (RGB) or 4 "
+ "(RGBA) channels, found %d",
+ image_data.channels));
+ }
+ if (image_data.pixel_data == nullptr) {
+ return absl::InvalidArgumentError(
+ "Expected pixel data to be set, found nullptr.");
+ }
+
+ if (stbi_write_png(
+ image_path.c_str(), image_data.width, image_data.height,
+ image_data.channels, image_data.pixel_data,
+ /*stride_in_bytes=*/image_data.width * image_data.channels) == 0) {
+ return absl::InternalError("An error occurred while encoding image.");
+ }
+
+ return absl::OkStatus();
+}
+
+void ImageDataFree(ImageData* image) { stbi_image_free(image->pixel_data); }
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h
new file mode 100644
index 00000000..38f62b60
--- /dev/null
+++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h
@@ -0,0 +1,58 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Image data with pixels stored as a row-major flattened array.
+// Channels can be:
+// 1 : grayscale
+// 3 : RGB, interleaved
+// 4 : RGBA, interleaved
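+// For interleaved formats, the component c of the pixel at (x, y) is stored at
+// index `channels * (width * y + x) + c` (this matches the indexing used by the
+// demo drawing code).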
+struct ImageData {
+ uint8* pixel_data;
+ int width;
+ int height;
+ int channels;
+};
+
+// Decodes an image file and returns the corresponding image if no error
+// occurred. If decoding succeeded, the caller must manage deletion of the
+// underlying pixel data using `ImageDataFree`.
+// Supports a wide range of image formats, listed in `stb_image/stb_image.h`.
+tflite::support::StatusOr<ImageData> DecodeImageFromFile(
+ const std::string& file_name);
+
+// Encodes the provided ImageData as a lossless PNG and writes it to the
+// provided path.
+absl::Status EncodeImageToPngFile(const ImageData& image_data,
+ const std::string& image_path);
+
+// Releases image pixel data memory.
+void ImageDataFree(ImageData* image);
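+
+// A typical flow (illustrative sketch, mirroring the desktop demos):
+//   ASSIGN_OR_RETURN(ImageData image, DecodeImageFromFile("/path/to/input.jpg"));
+//   RETURN_IF_ERROR(EncodeImageToPngFile(image, "/path/to/output.png"));
+//   ImageDataFree(&image);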
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_
diff --git a/tensorflow_lite_support/ios/BUILD b/tensorflow_lite_support/ios/BUILD
new file mode 100644
index 00000000..07b89515
--- /dev/null
+++ b/tensorflow_lite_support/ios/BUILD
@@ -0,0 +1,48 @@
+# TensorFlow Lite Task Library - Text
+
+load(
+ "@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl",
+ "TFL_MINIMUM_OS_VERSION",
+ "tflite_ios_static_framework",
+)
+load(
+ "//tensorflow_lite_support/ios:ios.bzl",
+ "strip_c_api_include_path_prefix",
+)
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+strip_c_api_include_path_prefix(
+ name = "strip_c_api_include_path",
+ hdr_labels = [
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier_c_api.h",
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api.h",
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api_common.h",
+ "//tensorflow_lite_support/cc/task/text/qa:bert_qa_c_api.h",
+ ],
+)
+
+# This target builds a monolithic static framework for the TFLite Text API,
+# which includes the TFLite runtime in it.
+#
+# bazel build -c opt --config=ios_fat //tensorflow_lite_support/ios:TensorFlowLiteTaskTextC_framework
+tflite_ios_static_framework(
+ name = "TensorFlowLiteTaskTextC_framework",
+ hdrs = [
+ ":bert_nl_classifier_c_api.h",
+ ":bert_qa_c_api.h",
+ ":nl_classifier_c_api.h",
+ ":nl_classifier_c_api_common.h",
+ ],
+ allowlist_symbols_file = ":allowlist_TensorFlowLiteTaskText.txt",
+ bundle_name = "TensorFlowLiteTaskTextC",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ deps = [
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier_c_api",
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api",
+ "//tensorflow_lite_support/cc/task/text/qa:bert_qa_c_api",
+ ],
+)
diff --git a/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template b/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template
new file mode 100644
index 00000000..62c3f339
--- /dev/null
+++ b/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template
@@ -0,0 +1,44 @@
+Pod::Spec.new do |s|
+ s.name = 'TensorFlowLiteTaskText'
+ s.version = '${TFLS_BUILD_VERSION}'
+ s.authors = 'Google Inc.'
+ s.license = { :type => 'Apache' }
+ s.homepage = 'https://github.com/tensorflow/tflite-support'
+ s.source = { :http => '${TFLS_DOWNLOAD_URL}' }
+ s.summary = 'TensorFlow Lite Task Library - Text'
+ s.description = 'The Natural Language APIs of the TFLite Task Library'
+
+ s.ios.deployment_target = '9.0'
+
+ s.module_name = 'TensorFlowLiteTaskText'
+ s.static_framework = true
+
+ s.dependency 'GoogleToolboxForMac', '2.2.1'
+
+ objc_dir = 'tensorflow_lite_support/ios/task/text/'
+ s.public_header_files = [
+ objc_dir + 'apis/*.h',
+ objc_dir + '{nlclassifier,qa}/Sources/*.h'
+ ]
+
+ cc_dir = 'tensorflow_lite_support/cc/task/text/'
+ s.source_files = [
+ cc_dir + '{nlclassifier,qa}/*_c_api*.h',
+ objc_dir + 'apis/*.h',
+ objc_dir + '{nlclassifier,qa}/Sources/*.{h,m,mm}'
+ ]
+ s.module_map = objc_dir + 'apis/framework.modulemap'
+ s.pod_target_xcconfig = {
+ 'HEADER_SEARCH_PATHS' =>
+ '"${PODS_TARGET_SRCROOT}" ' +
+ '"${PODS_TARGET_SRCROOT}/' + cc_dir + 'nlclassifier" ' +
+ '"${PODS_TARGET_SRCROOT}/' + cc_dir + 'qa" ' +
+ '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'apis" ' +
+ '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'nlclassifier/Sources" ' +
+ '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'qa/Sources"',
+ 'VALID_ARCHS' => 'x86_64 armv7 arm64',
+ }
+
+ s.library = 'c++'
+ s.vendored_frameworks = 'Frameworks/TensorFlowLiteTaskTextC.framework'
+end
diff --git a/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt b/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt
new file mode 100644
index 00000000..3af5b0b1
--- /dev/null
+++ b/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt
@@ -0,0 +1,3 @@
+_NLClassifier*
+_BertNLClassifier*
+_BertQuestionAnswerer*
diff --git a/tensorflow_lite_support/ios/ios.bzl b/tensorflow_lite_support/ios/ios.bzl
new file mode 100644
index 00000000..cb8c92ac
--- /dev/null
+++ b/tensorflow_lite_support/ios/ios.bzl
@@ -0,0 +1,30 @@
+"""TensorFlow Lite Support Library Helper Rules for iOS"""
+
+# When the static framework is built with bazel, all header files are moved
+# to the "Headers" directory with no header path prefixes. This auxiliary rule
+# is used for stripping the path prefix to the C API header files included by
+# other C API header files.
+def strip_c_api_include_path_prefix(name, hdr_labels, prefix = ""):
+ """Create modified header files with the common.h include path stripped out.
+
+ Args:
+ name: The name to be used as a prefix to the generated genrules.
+ hdr_labels: List of header labels to strip out the include path. Each
+ label must end with a colon followed by the header file name.
+ prefix: Optional prefix path to prepend to the header inclusion path.
+ """
+
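+    # For example (illustrative), an inclusion such as
+    #   #include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h"
+    # is rewritten by the genrule below to
+    #   #include "nl_classifier_c_api_common.h"
+    # (with `prefix` prepended when one is provided).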
+ for hdr_label in hdr_labels:
+ hdr_filename = hdr_label.split(":")[-1]
+ hdr_basename = hdr_filename.split(".")[0]
+
+ native.genrule(
+ name = "{}_{}".format(name, hdr_basename),
+ srcs = [hdr_label],
+ outs = [hdr_filename],
+ cmd = """
+ sed 's|#include ".*/\\([^/]\\{{1,\\}}\\.h\\)"|#include "{}\\1"|'\
+ "$(location {})"\
+ > "$@"
+ """.format(prefix, hdr_label),
+ )
diff --git a/tensorflow_lite_support/ios/task/text/apis/TFLTaskText.h b/tensorflow_lite_support/ios/task/text/apis/TFLTaskText.h
new file mode 100644
index 00000000..a42a4b38
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/apis/TFLTaskText.h
@@ -0,0 +1,17 @@
+// Copyright 2020 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "TFLBertNLClassifier.h"
+#import "TFLBertQuestionAnswerer.h"
+#import "TFLNLClassifier.h"
diff --git a/tensorflow_lite_support/ios/task/text/apis/framework.modulemap b/tensorflow_lite_support/ios/task/text/apis/framework.modulemap
new file mode 100644
index 00000000..3e267620
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/apis/framework.modulemap
@@ -0,0 +1,4 @@
+framework module TensorFlowLiteTaskText {
+ umbrella header "TFLTaskText.h"
+ export *
+}
diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD b/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD
new file mode 100644
index 00000000..fb369e90
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD
@@ -0,0 +1,125 @@
+load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
+load("@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION")
+load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
+load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+objc_library(
+ name = "TFLBertNLClassifier",
+ srcs = ["Sources/TFLBertNLClassifier.m"],
+ hdrs = ["Sources/TFLBertNLClassifier.h"],
+ module_name = "TFLBertNLClassifier",
+ deps = [
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier_c_api",
+ "@google_toolbox_for_mac//:GTM_Defines",
+ ],
+)
+
+swift_library(
+ name = "TFLBertNLClassifierSwiftTestLibrary",
+ testonly = 1,
+ srcs = ["Tests/TFLBertNLClassifierTest.swift"],
+ data = [
+ "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models",
+ ],
+ tags = TFL_DEFAULT_TAGS,
+ deps = [
+ ":TFLBertNLClassifier",
+ "//third_party/swift/xctest",
+ ],
+)
+
+ios_unit_test(
+ name = "TFLBertNLClassifierSwiftTest",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ runner = tflite_ios_lab_runner("IOS_LATEST"),
+ tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+ deps = [
+ ":TFLBertNLClassifierSwiftTestLibrary",
+ ],
+)
+
+objc_library(
+ name = "TFLBertNLClassifierObjcTestLibrary",
+ testonly = 1,
+ srcs = ["Tests/TFLBertNLClassifierTest.m"],
+ data = [
+ "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models",
+ ],
+ tags = TFL_DEFAULT_TAGS,
+ deps = [
+ ":TFLBertNLClassifier",
+ ],
+)
+
+ios_unit_test(
+ name = "TFLBertNLClassifierObjcTest",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ runner = tflite_ios_lab_runner("IOS_LATEST"),
+ tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+ deps = [
+ ":TFLBertNLClassifierObjcTestLibrary",
+ ],
+)
+
+objc_library(
+ name = "TFLNLClassifier",
+ srcs = ["Sources/TFLNLClassifier.m"],
+ hdrs = ["Sources/TFLNLClassifier.h"],
+ module_name = "TFLNLClassifier",
+ deps = [
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api",
+ "@google_toolbox_for_mac//:GTM_Defines",
+ ],
+)
+
+swift_library(
+ name = "TFLNLClassifierSwiftTestLibrary",
+ testonly = 1,
+ srcs = ["Tests/TFLNLClassifierTest.swift"],
+ data = [
+ "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models",
+ ],
+ tags = TFL_DEFAULT_TAGS,
+ deps = [
+ ":TFLNLClassifier",
+ "//third_party/swift/xctest",
+ ],
+)
+
+ios_unit_test(
+ name = "TFLNLClassifierSwiftTest",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ runner = tflite_ios_lab_runner("IOS_LATEST"),
+ tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+ deps = [
+ ":TFLNLClassifierSwiftTestLibrary",
+ ],
+)
+
+objc_library(
+ name = "TFLNLClassifierObjcTestLibrary",
+ testonly = 1,
+ srcs = ["Tests/TFLNLClassifierTest.m"],
+ data = [
+ "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models",
+ ],
+ tags = TFL_DEFAULT_TAGS,
+ deps = [
+ ":TFLNLClassifier",
+ ],
+)
+
+ios_unit_test(
+ name = "TFLNLClassifierObjcTest",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ runner = tflite_ios_lab_runner("IOS_LATEST"),
+ tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+ deps = [
+ ":TFLNLClassifierObjcTestLibrary",
+ ],
+)
diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
new file mode 100644
index 00000000..ceed6fa8
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Classifier API for NLClassification tasks with Bert models; categorizes a string into
+ * different classes. The API expects a Bert-based TFLite model with metadata populated.
+ *
+ * The metadata should contain the following information:
+ * 1 input_process_unit for Wordpiece/Sentencepiece Tokenizer.
+ * 3 input tensors with names "ids", "mask" and "segment_ids".
+ *   1 output tensor of type float32[1, 2], with an optionally attached label file. If a
+ *     label file is attached, it should be a plain text file with one label per line, and
+ *     the number of labels should match the number of categories the model outputs.
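+ *
+ * A minimal usage sketch (illustrative; assumes a compatible model at `modelPath`):
+ *
+ *   TFLBertNLClassifier *classifier =
+ *       [TFLBertNLClassifier bertNLClassifierWithModelPath:modelPath];
+ *   NSDictionary<NSString *, NSNumber *> *categories =
+ *       [classifier classifyWithText:@"it's a charming and often affecting journey"];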
+ */
+@interface TFLBertNLClassifier : NSObject
+
+/**
+ * Creates TFLBertNLClassifier from a model file.
+ *
+ * @param modelPath Path to the classification model.
+ * @return A TFLBertNLClassifier instance.
+ */
++ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath
+ NS_SWIFT_NAME(bertNLClassifier(modelPath:));
+
+/**
+ * Performs classification on an NSString input, returns <NSString *, NSNumber *>
+ * for categories and scores.
+ *
+ * @param text Input text to the model.
+ * @return An NSDictionary of categorization results.
+ */
+- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text
+ NS_SWIFT_NAME(classify(text:));
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m
new file mode 100644
index 00000000..24c37b78
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m
@@ -0,0 +1,60 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h"
+#import "GTMDefines.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+@interface TFLBertNLClassifier ()
+/** BertNLClassifier backed by C API */
+@property(nonatomic) BertNLClassifier *bertNLClassifier;
+@end
+
+@implementation TFLBertNLClassifier
+
+- (void)dealloc {
+ BertNLClassifierDelete(_bertNLClassifier);
+}
+
++ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath {
+ BertNLClassifier *classifier = BertNLClassifierFromFile(modelPath.UTF8String);
+
+ _GTMDevAssert(classifier, @"Failed to create BertNLClassifier");
+ return [[TFLBertNLClassifier alloc] initWithBertNLClassifier:classifier];
+}
+
+- (instancetype)initWithBertNLClassifier:(BertNLClassifier *)bertNLClassifier {
+ self = [super init];
+ if (self) {
+ _bertNLClassifier = bertNLClassifier;
+ }
+ return self;
+}
+
+- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text {
+ struct Categories *cCategories = BertNLClassifierClassify(_bertNLClassifier, text.UTF8String);
+ NSMutableDictionary<NSString *, NSNumber *> *ret = [NSMutableDictionary dictionary];
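+  // Convert the C API categories into a dictionary mapping category names to scores.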
+ for (int i = 0; i < cCategories->size; i++) {
+ struct Category cCategory = cCategories->categories[i];
+ [ret setValue:[NSNumber numberWithDouble:cCategory.score]
+ forKey:[NSString stringWithUTF8String:cCategory.text]];
+ }
+ NLClassifierCategoriesDelete(cCategories);
+ return ret;
+}
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
new file mode 100644
index 00000000..ceb8d2ef
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h
@@ -0,0 +1,86 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+/**
+ * Options to identify input and output tensors of the model.
+ */
+@interface TFLNLClassifierOptions : NSObject
+@property(nonatomic) int inputTensorIndex;
+@property(nonatomic) int outputScoreTensorIndex;
+@property(nonatomic) int outputLabelTensorIndex;
+@property(nonatomic) NSString *inputTensorName;
+@property(nonatomic) NSString *outputScoreTensorName;
+@property(nonatomic) NSString *outputLabelTensorName;
+@end
+
+/**
+ * Classifier API for natural language classification tasks; categorizes a string into
+ * different classes.
+ *
+ * The API expects a TFLite model with the following input/output tensors:
+ *
+ * Input tensor (kTfLiteString)
+ * input of the model, accepts a string.
+ *
+ * Output score tensor
+ *   (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool)
+ *   output scores for each class. If the type is one of the Int types, the scores are
+ *   dequantized; if it is the Bool type, the values are converted to 0.0 and 1.0
+ *   respectively.
+ *
+ *   The score tensor can have an optional associated file in the metadata for labels; the
+ *   file should be a plain text file with one label per line, and the number of labels
+ *   should match the number of categories the model outputs.
+ *
+ * Optional output label tensor (kTfLiteString/kTfLiteInt32)
+ *   output class name for each class; should be of the same length as the scores. If this
+ *   tensor is not present, the API uses score indices as class names. It is ignored if the
+ *   output score tensor already has an associated label file.
+ *
+ * By default the API tries to find the input/output tensors with the default configurations
+ * in TFLNLClassifierOptions, with tensor names prioritized over tensor indices. The options
+ * are configurable for different TFLite models.
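+ *
+ * A minimal usage sketch (illustrative; the tensor names here match the regex-tokenizer
+ * test model and depend on the actual model used):
+ *
+ *   TFLNLClassifierOptions *options = [[TFLNLClassifierOptions alloc] init];
+ *   options.inputTensorName = @"input_text";
+ *   options.outputScoreTensorName = @"probability";
+ *   TFLNLClassifier *classifier =
+ *       [TFLNLClassifier nlClassifierWithModelPath:modelPath options:options];
+ *   NSDictionary<NSString *, NSNumber *> *categories =
+ *       [classifier classifyWithText:@"What a waste of my time."];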
+ */
+@interface TFLNLClassifier : NSObject
+
+/**
+ * Creates a TFLNLClassifier instance from TFLNLClassifierOptions.
+ *
+ * @param modelPath The file path to the tflite model.
+ * @param options The TFLNLClassifierOptions to configure the model.
+ *
+ * @return A TFLNLClassifier instance.
+ */
++ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath
+ options:(TFLNLClassifierOptions *)options
+ NS_SWIFT_NAME(nlClassifier(modelPath:options:));
+
+/**
+ * Performs classification on an NSString input, returns <NSString *, NSNumber *>
+ * for categories and scores.
+ *
+ * @param text Input text to the model.
+ * @return An NSDictionary of categorization results.
+ */
+- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text
+ NS_SWIFT_NAME(classify(text:));
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m
new file mode 100644
index 00000000..01a48188
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m
@@ -0,0 +1,79 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h"
+#import "GTMDefines.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+@implementation TFLNLClassifierOptions
+@synthesize inputTensorIndex;
+@synthesize outputScoreTensorIndex;
+@synthesize outputLabelTensorIndex;
+@synthesize inputTensorName;
+@synthesize outputScoreTensorName;
+@synthesize outputLabelTensorName;
+@end
+
+@interface TFLNLClassifier ()
+/** NLClassifier backed by C API */
+@property(nonatomic) NLClassifier *nlClassifier;
+@end
+
+@implementation TFLNLClassifier
+
+- (void)dealloc {
+ NLClassifierDelete(_nlClassifier);
+}
+
++ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath
+ options:(TFLNLClassifierOptions *)options {
+ struct NLClassifierOptions cOptions = {
+ .input_tensor_index = options.inputTensorIndex,
+ .output_score_tensor_index = options.outputScoreTensorIndex,
+ .output_label_tensor_index = options.outputLabelTensorIndex,
+ .input_tensor_name = options.inputTensorName.UTF8String,
+ .output_score_tensor_name =
+ options.outputScoreTensorName.UTF8String,
+ .output_label_tensor_name =
+ options.outputLabelTensorName.UTF8String
+ };
+ NLClassifier *classifier = NLClassifierFromFileAndOptions(modelPath.UTF8String, &cOptions);
+ _GTMDevAssert(classifier, @"Failed to create NLClassifier");
+ return [[TFLNLClassifier alloc] initWithNLClassifier:classifier];
+}
+
+- (instancetype)initWithNLClassifier:(NLClassifier *)nlClassifier {
+ self = [super init];
+ if (self) {
+ _nlClassifier = nlClassifier;
+ }
+ return self;
+}
+
+- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text {
+ struct Categories *cCategories = NLClassifierClassify(_nlClassifier, text.UTF8String);
+ NSMutableDictionary<NSString *, NSNumber *> *ret = [NSMutableDictionary dictionary];
+ for (int i = 0; i < cCategories->size; i++) {
+ struct Category cCategory = cCategories->categories[i];
+ [ret setValue:[NSNumber numberWithDouble:cCategory.score]
+ forKey:[NSString stringWithUTF8String:cCategory.text]];
+ }
+ NLClassifierCategoriesDelete(cCategories);
+ return ret;
+}
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m
new file mode 100644
index 00000000..9565bfb2
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m
@@ -0,0 +1,61 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h"
+
+#import <XCTest/XCTest.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+@interface TFLBertNLClassifierTest : XCTestCase
+@property(nonatomic, nullable) NSString *bertModelPath;
+@end
+
+@implementation TFLBertNLClassifierTest
+#pragma mark - Tests
+
+- (void)setUp {
+ [super setUp];
+ NSBundle *bundle = [NSBundle bundleForClass:[self class]];
+ self.bertModelPath = [bundle pathForResource:@"test_model_nl_classifier_bert"
+ ofType:@"tflite"];
+}
+
+- (void)testClassifyPositiveResult {
+ TFLBertNLClassifier* bertNLClassifier =
+ [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath];
+
+ XCTAssertNotNil(bertNLClassifier);
+
+ NSDictionary<NSString *, NSNumber *> * categories =
+ [bertNLClassifier classifyWithText:@"it's a charming and often affecting journey"];
+
+ XCTAssertGreaterThan([categories[@"positive"] doubleValue],
+ [categories[@"negative"] doubleValue]);
+}
+
+- (void)testClassifyNegativeResult {
+ TFLBertNLClassifier* bertNLClassifier =
+ [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath];
+
+ XCTAssertNotNil(bertNLClassifier);
+
+ NSDictionary<NSString *, NSNumber *> * categories =
+ [bertNLClassifier classifyWithText:@"unflinchingly bleak and desperate"];
+
+ XCTAssertGreaterThan([categories[@"negative"] doubleValue],
+ [categories[@"positive"] doubleValue]);
+}
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift
new file mode 100644
index 00000000..d331b04e
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift
@@ -0,0 +1,45 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+import XCTest
+
+@testable import TFLBertNLClassifier
+
+class TFLBertNLClassifierTest: XCTestCase {
+
+ static let bundle = Bundle(for: TFLBertNLClassifierTest.self)
+ static let bertModelPath = bundle.path(forResource: "test_model_nl_classifier_bert", ofType: "tflite")!
+
+ func testClassifyPositiveResult() {
+ let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier(
+ modelPath: TFLBertNLClassifierTest.bertModelPath)
+
+ XCTAssertNotNil(bertNLClassifier)
+
+ let categories = bertNLClassifier.classify(text: "it's a charming and often affecting journey")
+
+ XCTAssertGreaterThan(categories["positive"]!.doubleValue, categories["negative"]!.doubleValue)
+ }
+
+ func testClassifyNegativeResult() {
+ let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier(
+ modelPath: TFLBertNLClassifierTest.bertModelPath)
+
+ XCTAssertNotNil(bertNLClassifier)
+
+ let categories = bertNLClassifier.classify(text: "unflinchingly bleak and desperate")
+
+ XCTAssertGreaterThan(categories["negative"]!.doubleValue, categories["positive"]!.doubleValue)
+ }
+}
diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m
new file mode 100644
index 00000000..40814ac6
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m
@@ -0,0 +1,65 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h"
+
+#import <XCTest/XCTest.h>
+
+NS_ASSUME_NONNULL_BEGIN
+
+@interface TFLNLClassifierTest : XCTestCase
+@property(nonatomic, nullable) NSString *modelPath;
+@property(nonatomic, nullable) TFLNLClassifierOptions *modelOptions;
+@end
+
+@implementation TFLNLClassifierTest
+#pragma mark - Tests
+
+- (void)setUp {
+ [super setUp];
+ NSBundle *bundle = [NSBundle bundleForClass:[self class]];
+ self.modelPath = [bundle pathForResource:@"test_model_nl_classifier_with_regex_tokenizer"
+ ofType:@"tflite"];
+ self.modelOptions = [[TFLNLClassifierOptions alloc] init];
+ [self.modelOptions setInputTensorName:@"input_text"];
+ [self.modelOptions setOutputScoreTensorName:@"probability"];
+}
+
+- (void)testClassifyPositiveResult {
+ TFLNLClassifier *nlClassifier = [TFLNLClassifier nlClassifierWithModelPath:self.modelPath
+ options:self.modelOptions];
+
+ XCTAssertNotNil(nlClassifier);
+
+ NSDictionary<NSString *, NSNumber *> *categories = [nlClassifier
+ classifyWithText:@"This is the best movie I’ve seen in recent years. Strongly recommend it!"];
+
+ XCTAssertGreaterThan([categories[@"Positive"] doubleValue],
+ [categories[@"Negative"] doubleValue]);
+}
+
+- (void)testClassifyNegativeResult {
+ TFLNLClassifier *nlClassifier = [TFLNLClassifier nlClassifierWithModelPath:self.modelPath
+ options:self.modelOptions];
+
+ XCTAssertNotNil(nlClassifier);
+
+ NSDictionary<NSString *, NSNumber *> *categories =
+ [nlClassifier classifyWithText:@"What a waste of my time."];
+
+ XCTAssertGreaterThan([categories[@"Negative"] doubleValue],
+ [categories[@"Positive"] doubleValue]);
+}
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.swift b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.swift
new file mode 100644
index 00000000..fb80e5da
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.swift
@@ -0,0 +1,58 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+import XCTest
+
+@testable import TFLNLClassifier
+
+class TFLNLClassifierTest: XCTestCase {
+
+ static let bundle = Bundle(for: TFLNLClassifierTest.self)
+ static let modelPath = bundle.path(
+ forResource: "test_model_nl_classifier_with_regex_tokenizer",
+ ofType: "tflite")!
+
+  var modelOptions: TFLNLClassifierOptions!
+
+ override func setUp() {
+ modelOptions = TFLNLClassifierOptions()
+ modelOptions.inputTensorName = "input_text"
+ modelOptions.outputScoreTensorName = "probability"
+ }
+
+ func testClassifyPositiveResult() {
+ let nlClassifier = TFLNLClassifier.nlClassifier(
+ modelPath: TFLNLClassifierTest.modelPath,
+ options: modelOptions)
+
+ XCTAssertNotNil(nlClassifier)
+
+ let categories = nlClassifier.classify(
+ text: "This is the best movie I’ve seen in recent years. Strongly recommend it!")
+
+ XCTAssertGreaterThan(categories["Positive"]!.doubleValue, categories["Negative"]!.doubleValue)
+ }
+
+ func testClassifyNegativeResult() {
+ let nlClassifier = TFLNLClassifier.nlClassifier(
+ modelPath: TFLNLClassifierTest.modelPath,
+ options: modelOptions)
+
+ XCTAssertNotNil(nlClassifier)
+
+ let categories = nlClassifier.classify(text: "What a waste of my time.")
+
+ XCTAssertGreaterThan(categories["Negative"]!.doubleValue, categories["Positive"]!.doubleValue)
+ }
+}
diff --git a/tensorflow_lite_support/ios/task/text/qa/BUILD b/tensorflow_lite_support/ios/task/text/qa/BUILD
new file mode 100644
index 00000000..7998a8e9
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/qa/BUILD
@@ -0,0 +1,71 @@
+load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
+load("@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION")
+load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
+load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+objc_library(
+ name = "TFLBertQuestionAnswerer",
+ srcs = ["Sources/TFLBertQuestionAnswerer.m"],
+ hdrs = ["Sources/TFLBertQuestionAnswerer.h"],
+ module_name = "TFLBertQuestionAnswerer",
+ deps = [
+ "//tensorflow_lite_support/cc/task/text/qa:bert_qa_c_api",
+ "@google_toolbox_for_mac//:GTM_Defines",
+ ],
+)
+
+swift_library(
+ name = "TFLBertQuestionAnswererSwiftTestLibrary",
+ testonly = 1,
+ srcs = glob(["Tests/*.swift"]),
+ data = [
+ "//tensorflow_lite_support/cc/test/testdata/task/text:albert_model",
+ "//tensorflow_lite_support/cc/test/testdata/task/text:mobile_bert_model",
+ ],
+ tags = TFL_DEFAULT_TAGS,
+ deps = [
+ ":TFLBertQuestionAnswerer",
+ "//third_party/swift/xctest",
+ ],
+)
+
+ios_unit_test(
+ name = "TFLBertQuestionAnswererSwiftTest",
+ size = "large",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ runner = tflite_ios_lab_runner("IOS_LATEST"),
+ tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+ deps = [
+ ":TFLBertQuestionAnswererSwiftTestLibrary",
+ ],
+)
+
+objc_library(
+ name = "TFLBertQuestionAnswererObjcTestLibrary",
+ testonly = 1,
+ srcs = glob(["Tests/*.m"]),
+ data = [
+ "//tensorflow_lite_support/cc/test/testdata/task/text:albert_model",
+ "//tensorflow_lite_support/cc/test/testdata/task/text:mobile_bert_model",
+ ],
+ tags = TFL_DEFAULT_TAGS,
+ deps = [
+ ":TFLBertQuestionAnswerer",
+ ],
+)
+
+ios_unit_test(
+ name = "TFLBertQuestionAnswererObjcTest",
+ size = "large",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ runner = tflite_ios_lab_runner("IOS_LATEST"),
+ tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+ deps = [
+ ":TFLBertQuestionAnswererObjcTestLibrary",
+ ],
+)
diff --git a/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h b/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
new file mode 100644
index 00000000..57b7c69c
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h
@@ -0,0 +1,74 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+/**
+ * Struct representing the logit and the start/end offsets of an answer within the context.
+ */
+struct TFLPos {
+ int start;
+ int end;
+ float logit;
+};
+
+/**
+ * Class representing an answer returned by BertQuestionAnswerer.
+ */
+@interface TFLQAAnswer : NSObject
+@property(nonatomic) struct TFLPos pos;
+@property(nonatomic) NSString* text;
+@end
+
+/**
+ * BertQA task API; performs tokenization for models (BERT, Albert, etc.) during
+ * preprocessing and returns the most likely answers.
+ *
+ * In particular, BERT models use the WordPiece tokenizer, while Albert models
+ * use the SentencePiece tokenizer.
+ */
+@interface TFLBertQuestionAnswerer : NSObject
+
+/**
+ * Creates a BertQuestionAnswerer instance with an Albert or MobileBert model.
+ * The API expects a Bert-based TFLite model with metadata containing
+ * the following information:
+ * input_process_units: for Wordpiece/Sentencepiece Tokenizer
+ * 3 input tensors with names "ids", "mask" and "segment_ids"
+ * 2 output tensors with names "end_logits" and "start_logits"
+ * Sample models:
+ * https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
+ * https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
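+ *
+ * A minimal usage sketch (illustrative; assumes a compatible model at `modelPath`):
+ *
+ *   TFLBertQuestionAnswerer *answerer =
+ *       [TFLBertQuestionAnswerer questionAnswererWithModelPath:modelPath];
+ *   NSArray<TFLQAAnswer *> *answers =
+ *       [answerer answerWithContext:context question:question];
+ *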
+ * @param modelPath The file path to the tflite model.
+ * @return A BertQuestionAnswerer instance.
+ */
++ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath
+ NS_SWIFT_NAME(questionAnswerer(modelPath:));
+
+/**
+ * Answers a question based on the given context. The result could be empty if
+ * no answer was found in the context.
+ *
+ * @param context Context the question bases on.
+ * @param question Question to ask.
+ *
+ * @return A list of answers to the question, sorted in descending order of
+ * probability.
+ */
+- (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
+ question:(NSString*)question
+ NS_SWIFT_NAME(answer(context:question:));
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m b/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m
new file mode 100644
index 00000000..fc1bd08b
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m
@@ -0,0 +1,71 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h"
+#import "GTMDefines.h"
+#include "tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h"
+
+NS_ASSUME_NONNULL_BEGIN
+
+@implementation TFLQAAnswer
+@synthesize pos;
+@synthesize text;
+@end
+
+@interface TFLBertQuestionAnswerer()
+/** BertQuestionAnswerer backed by C API */
+@property(nonatomic) BertQuestionAnswerer *bertQuestionAnswerer;
+@end
+
+@implementation TFLBertQuestionAnswerer
+
+- (void)dealloc {
+ BertQuestionAnswererDelete(_bertQuestionAnswerer);
+}
+
++ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath {
+ BertQuestionAnswerer* bert_qa = BertQuestionAnswererFromFile(modelPath.UTF8String);
+
+ _GTMDevAssert(bert_qa, @"Failed to create BertQuestionAnswerer");
+ return [[TFLBertQuestionAnswerer alloc]
+ initWithBertQuestionAnswerer:bert_qa];
+}
+
+- (instancetype)initWithBertQuestionAnswerer:(BertQuestionAnswerer *)bertQuestionAnswerer {
+ self = [super init];
+ if (self) {
+ _bertQuestionAnswerer = bertQuestionAnswerer;
+ }
+ return self;
+}
+
+- (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
+ struct QaAnswers *cAnswers =
+ BertQuestionAnswererAnswer(_bertQuestionAnswerer, context.UTF8String, question.UTF8String);
+ NSMutableArray<TFLQAAnswer *> *ret = [NSMutableArray arrayWithCapacity:cAnswers->size];
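+  // Convert each C API QaAnswer into a TFLQAAnswer object.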
+ for (int i = 0; i < cAnswers->size; i++) {
+ struct QaAnswer cAnswer = cAnswers->answers[i];
+ TFLQAAnswer *answer = [[TFLQAAnswer alloc] init];
+ struct TFLPos pos = {.start = cAnswer.start,
+ .end = cAnswer.end,
+ .logit = cAnswer.logit};
+ [answer setPos:pos];
+ [answer setText:[NSString stringWithUTF8String:cAnswer.text]];
+ [ret addObject:answer];
+ }
+ BertQuestionAnswererQaAnswersDelete(cAnswers);
+ return ret;
+}
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m b/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m
new file mode 100644
index 00000000..90610630
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m
@@ -0,0 +1,72 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h"
+
+#import <XCTest/XCTest.h>
+
+static NSString *const kContext =
+ @"The role of teacher is often formal and ongoing, carried out at a school "
+ "or other place of formal education. In many countries, a person who "
+ "wishes to become a teacher must first obtain specified professional "
+ "qualifications or credentials from a university or college. These "
+ "professional qualifications may include the study of pedagogy, the "
+ "science of teaching. Teachers, like other professionals, may have to "
+ "continue their education after they qualify, a process known as "
+ "continuing professional development. Teachers may use a lesson plan to "
+ "facilitate student learning, providing a course of study which is called "
+ "the curriculum.";
+static NSString *const kQuestion = @"What is a course of study called?";
+static NSString *const kAnswer = @"the curriculum.";
+
+@interface TFLBertQuestionAnswererTest : XCTestCase
+@property(nonatomic, nullable) NSString *mobileBertModelPath;
+@property(nonatomic, nullable) NSString *albertModelPath;
+@end
+
+@implementation TFLBertQuestionAnswererTest
+#pragma mark - Tests
+
+- (void)setUp {
+ [super setUp];
+ NSBundle *bundle = [NSBundle bundleForClass:[self class]];
+ self.mobileBertModelPath = [bundle pathForResource:@"mobilebert_with_metadata" ofType:@"tflite"];
+ self.albertModelPath = [bundle pathForResource:@"albert_with_metadata" ofType:@"tflite"];
+}
+
+- (void)testInitMobileBert {
+ TFLBertQuestionAnswerer* mobileBertAnswerer =
+ [TFLBertQuestionAnswerer questionAnswererWithModelPath:self.mobileBertModelPath];
+
+ XCTAssertNotNil(mobileBertAnswerer);
+
+ NSArray<TFLQAAnswer*>* answers =
+ [mobileBertAnswerer answerWithContext:kContext question:kQuestion];
+
+ XCTAssertEqualObjects([answers[0] text], kAnswer);
+}
+
+- (void)testInitAlbert {
+ TFLBertQuestionAnswerer* albertAnswerer =
+ [TFLBertQuestionAnswerer questionAnswererWithModelPath:self.albertModelPath];
+
+ XCTAssertNotNil(albertAnswerer);
+
+ NSArray<TFLQAAnswer*>* answers =
+ [albertAnswerer answerWithContext:kContext question:kQuestion];
+
+ XCTAssertEqualObjects([answers[0] text], kAnswer);
+}
+@end
diff --git a/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.swift b/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.swift
new file mode 100644
index 00000000..3f15cc52
--- /dev/null
+++ b/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.swift
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+import XCTest
+
+@testable import TFLBertQuestionAnswerer
+
+class TFLBertQuestionAnswererTest: XCTestCase {
+
+ static let bundle = Bundle(for: TFLBertQuestionAnswererTest.self)
+ static let mobileBertModelPath = bundle.path(forResource: "mobilebert_with_metadata", ofType: "tflite")!
+
+ static let albertModelPath = bundle.path(forResource: "albert_with_metadata", ofType: "tflite")!
+
+ static let context = """
+ The role of teacher is often formal and ongoing, carried out at a school or other place of
+ formal education. In many countries, a person who wishes to become a teacher must first obtain
+ specified professional qualifications or credentials from a university or college. These
+ professional qualifications may include the study of pedagogy, the science of teaching.
+ Teachers, like other professionals, may have to continue their education after they qualify,
+ a process known as continuing professional development. Teachers may use a lesson plan to
+ facilitate student learning, providing a course of study which is called the curriculum.
+ """
+ static let question = "What is a course of study called?"
+ static let answer = "the curriculum."
+
+ func testInitMobileBert() {
+ let mobileBertAnswerer = TFLBertQuestionAnswerer.questionAnswerer(
+ modelPath: TFLBertQuestionAnswererTest.mobileBertModelPath)
+
+ XCTAssertNotNil(mobileBertAnswerer)
+
+ let answers = mobileBertAnswerer.answer(
+ context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
+
+ XCTAssertNotNil(answers)
+ XCTAssertEqual(answers[0].text, TFLBertQuestionAnswererTest.answer)
+ }
+
+ func testInitAlbert() {
+ let albertAnswerer = TFLBertQuestionAnswerer.questionAnswerer(
+ modelPath: TFLBertQuestionAnswererTest.albertModelPath)
+
+ XCTAssertNotNil(albertAnswerer)
+
+ let answers = albertAnswerer.answer(
+ context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
+
+ XCTAssertNotNil(answers)
+ XCTAssertEqual(answers[0].text, TFLBertQuestionAnswererTest.answer)
+ }
+}
diff --git a/tensorflow_lite_support/ios/text/tokenizers/BUILD b/tensorflow_lite_support/ios/text/tokenizers/BUILD
new file mode 100644
index 00000000..34ba9c6b
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/BUILD
@@ -0,0 +1,106 @@
+load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")
+load("@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION")
+load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
+load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+objc_library(
+ name = "TFLTokenizerUtil",
+ srcs = [
+ "Sources/TFLTokenizerUtil.mm",
+ ],
+ hdrs = [
+ "Sources/TFLTokenizerUtil.h",
+ ],
+ module_name = "TFLTokenizerUtil",
+ deps = [
+ "//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
+ "//tensorflow_lite_support/ios/utils:TFLStringUtil",
+ ],
+)
+
+objc_library(
+ name = "TFLBertTokenizer",
+ srcs = [
+ "Sources/TFLBertTokenizer.mm",
+ ],
+ hdrs = [
+ "Sources/TFLBertTokenizer.h",
+ "Sources/TFLTokenizer.h",
+ ],
+ module_name = "TFLBertTokenizer",
+ deps = [
+ ":TFLTokenizerUtil",
+ "//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer",
+ "//tensorflow_lite_support/ios/utils:TFLStringUtil",
+ ],
+)
+
+swift_library(
+ name = "TFLBertTokenizerTestLibrary",
+ testonly = 1,
+ srcs = ["Tests/TFLBertTokenizerTest.swift"],
+ data = [
+ "//tensorflow_lite_support/cc/test/testdata/task/text:mobile_bert_model",
+ ],
+ tags = TFL_DEFAULT_TAGS,
+ deps = [
+ ":TFLBertTokenizer",
+ "//third_party/swift/xctest",
+ ],
+)
+
+ios_unit_test(
+ name = "TFLBertTokenizerTest",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ runner = tflite_ios_lab_runner("IOS_LATEST"),
+ tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+ deps = [
+ ":TFLBertTokenizerTestLibrary",
+ ],
+)
+
+objc_library(
+ name = "TFLSentencepieceTokenizer",
+ srcs = [
+ "Sources/TFLSentencepieceTokenizer.mm",
+ ],
+ hdrs = [
+ "Sources/TFLSentencepieceTokenizer.h",
+ "Sources/TFLTokenizer.h",
+ ],
+ module_name = "TFLSentencepieceTokenizer",
+ deps = [
+ ":TFLTokenizerUtil",
+ "//tensorflow_lite_support/cc/text/tokenizers:sentencepiece_tokenizer",
+ "//tensorflow_lite_support/ios/utils:TFLStringUtil",
+ ],
+)
+
+swift_library(
+ name = "TFLSentencepieceTokenizerTestLibrary",
+ testonly = 1,
+ srcs = ["Tests/TFLSentencepieceTokenizerTest.swift"],
+ data = [
+ "//tensorflow_lite_support/cc/test/testdata/task/text:albert_model",
+ ],
+ tags = TFL_DEFAULT_TAGS,
+ deps = [
+ ":TFLSentencepieceTokenizer",
+ "//third_party/swift/xctest",
+ ],
+)
+
+ios_unit_test(
+ name = "TFLSentencepieceTokenizerTest",
+ minimum_os_version = TFL_MINIMUM_OS_VERSION,
+ runner = tflite_ios_lab_runner("IOS_LATEST"),
+ tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
+ deps = [
+ ":TFLSentencepieceTokenizerTestLibrary",
+ ],
+)
diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
new file mode 100644
index 00000000..aa692489
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h"
+
+NS_ASSUME_NONNULL_BEGIN
+/**
+ * WordPiece tokenizer implementation.
+ */
+@interface TFLBertTokenizer : NSObject <TFLTokenizer>
+
+/**
+ * Default initializer is not available.
+ */
+- (instancetype)init NS_UNAVAILABLE;
+
+/**
+ * Initializes the tokenizer with the path to a WordPiece vocabulary file.
+ */
+- (instancetype)initWithVocabPath:(NSString *)vocabPath NS_DESIGNATED_INITIALIZER;
+
+/**
+ * Initializes the tokenizer with a list of vocabulary tokens.
+ */
+- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab NS_DESIGNATED_INITIALIZER;
+@end
+NS_ASSUME_NONNULL_END
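
A short Swift sketch of the two initializers above; the inline vocabulary is a
toy example and the file-based path is a hypothetical bundled asset (the Swift
method names match the tests later in this change):

    import TFLBertTokenizer

    let fromVocab = TFLBertTokenizer(vocab: ["hell", "##o", "wor", "##ld"])
    let tokens = fromVocab.tokens(fromInput: "hello world")  // ["hell", "##o", "wor", "##ld"]
    let ids = fromVocab.ids(fromTokens: tokens)              // [0, 1, 2, 3]

    // Alternatively, load a full WordPiece vocabulary from a bundled file.
    if let vocabPath = Bundle.main.path(forResource: "vocab", ofType: "txt") {
      let fromFile = TFLBertTokenizer(vocabPath: vocabPath)
      _ = fromFile.tokens(fromInput: "hello world")
    }
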
diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm
new file mode 100644
index 00000000..949cef2b
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm
@@ -0,0 +1,57 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h"
+#include "third_party/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
+#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h"
+#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
+
+NS_ASSUME_NONNULL_BEGIN
+using BertTokenizerCPP = ::tflite::support::text::tokenizer::BertTokenizer;
+
+@implementation TFLBertTokenizer {
+ std::unique_ptr<BertTokenizerCPP> _bertTokenizer;
+}
+
+- (instancetype)initWithVocabPath:(NSString *)vocabPath {
+ self = [super init];
+ if (self) {
+ _bertTokenizer = absl::make_unique<BertTokenizerCPP>(MakeString(vocabPath));
+ }
+ return self;
+}
+
+- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab {
+ self = [super init];
+ if (self) {
+ std::vector<std::string> vocabCpp;
+ vocabCpp.reserve([vocab count]);
+ for (NSString *word in vocab) {
+ vocabCpp.emplace_back(MakeString(word));
+ }
+ _bertTokenizer = absl::make_unique<BertTokenizerCPP>(vocabCpp);
+ }
+ return self;
+}
+
+- (NSArray<NSString *> *)tokensFromInput:(NSString *)input {
+ return Tokenize(_bertTokenizer.get(), input);
+}
+
+- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens {
+ return ConvertTokensToIds(_bertTokenizer.get(), tokens);
+}
+
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
new file mode 100644
index 00000000..eef3bf1e
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h"
+
+NS_ASSUME_NONNULL_BEGIN
+/**
+ * SentencePiece tokenizer implementation.
+ */
+@interface TFLSentencepieceTokenizer : NSObject <TFLTokenizer>
+
+/**
+ * Default initializer is not available.
+ */
+- (instancetype)init NS_UNAVAILABLE;
+
+/**
+ * Initializes the tokenizer with the path to a SentencePiece model file.
+ */
+- (instancetype)initWithModelPath:(NSString *)modelPath NS_DESIGNATED_INITIALIZER;
+@end
+NS_ASSUME_NONNULL_END
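
The SentencePiece wrapper mirrors the BERT tokenizer; a brief Swift sketch,
assuming a SentencePiece model file (such as the ALBERT "30k-clean.model" used
by the tests) is bundled with the app:

    import TFLSentencepieceTokenizer

    if let spModelPath = Bundle.main.path(forResource: "30k-clean", ofType: "model") {
      let spTokenizer = TFLSentencepieceTokenizer(modelPath: spModelPath)
      let tokens = spTokenizer.tokens(fromInput: "good morning")
      let ids = spTokenizer.ids(fromTokens: tokens)
      print(tokens, ids)
    }
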
diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm
new file mode 100644
index 00000000..1e21cee5
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm
@@ -0,0 +1,45 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h"
+#include "third_party/absl/memory/memory.h"
+#include "third_party/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"
+#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h"
+#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
+
+NS_ASSUME_NONNULL_BEGIN
+using SentencepieceTokenizerCPP = ::tflite::support::text::tokenizer::SentencePieceTokenizer;
+
+@implementation TFLSentencepieceTokenizer {
+ std::unique_ptr<SentencepieceTokenizerCPP> _spTokenizer;
+}
+
+- (instancetype)initWithModelPath:(NSString *)modelPath {
+ self = [super init];
+ if (self) {
+ _spTokenizer = absl::make_unique<SentencepieceTokenizerCPP>(MakeString(modelPath));
+ }
+ return self;
+}
+
+- (NSArray<NSString *> *)tokensFromInput:(NSString *)input {
+ return Tokenize(_spTokenizer.get(), input);
+}
+
+- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens {
+ return ConvertTokensToIds(_spTokenizer.get(), tokens);
+}
+
+@end
+NS_ASSUME_NONNULL_END
diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
new file mode 100644
index 00000000..ee0972f8
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import <Foundation/Foundation.h>
+
+NS_ASSUME_NONNULL_BEGIN
+/**
+ * Protocol for a tokenizer used in model preprocessing.
+ */
+@protocol TFLTokenizer
+
+/**
+ * Performs tokenization on input text.
+ * @param input The input string to be tokenized.
+ *
+ * @return A list of tokens.
+ */
+- (NSArray<NSString *> *)tokensFromInput:(NSString *)input;
+
+/**
+ * Converts a list of tokens back to their corresponding IDs.
+ * @param tokens The tokens to be converted.
+ *
+ * @return A list of token IDs.
+ */
+- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens;
+@end
+NS_ASSUME_NONNULL_END
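
Because both concrete tokenizers conform to this protocol, calling code can be
written against TFLTokenizer generically; a minimal sketch with a hypothetical
helper function:

    import TFLBertTokenizer

    // Works with any conformer, e.g. TFLBertTokenizer or TFLSentencepieceTokenizer.
    func encode(_ text: String, with tokenizer: TFLTokenizer) -> [NSNumber] {
      let tokens = tokenizer.tokens(fromInput: text)
      return tokenizer.ids(fromTokens: tokens)
    }
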
diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
new file mode 100644
index 00000000..574b5553
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import <Foundation/Foundation.h>
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
+
+using ::tflite::support::text::tokenizer::Tokenizer;
+
+/**
+ * Invokes the C++ tokenizer's `Tokenize` function and converts the input and
+ * output to Objective-C types.
+ *
+ * @param tokenizer The C++ tokenizer pointer.
+ * @param input The input string to be tokenized.
+ *
+ * @return A list of tokens.
+ */
+NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input);
+
+/**
+ * Invokes the C++ tokenizer's `LookupId` function for each token and converts
+ * the input and output to Objective-C types.
+ *
+ * @param tokenizer The C++ tokenizer pointer.
+ * @param tokens The tokens to be converted.
+ *
+ * @return A list of token IDs.
+ */
+NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens);
diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm
new file mode 100644
index 00000000..52180578
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm
@@ -0,0 +1,41 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h"
+
+#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
+
+using ::tflite::support::text::tokenizer::TokenizerResult;
+
+NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input) {
+ TokenizerResult tokenize_result = tokenizer->Tokenize(MakeString(input));
+ std::vector<std::string> subwords = tokenize_result.subwords;
+ NSMutableArray<NSString *> *ret = [NSMutableArray arrayWithCapacity:subwords.size()];
+  for (size_t i = 0; i < subwords.size(); ++i) {
+ [ret addObject:MakeNSString(subwords[i])];
+ }
+ return ret;
+}
+
+NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens) {
+ NSMutableArray<NSNumber *> *ret = [NSMutableArray arrayWithCapacity:[tokens count]];
+ for (NSString *token in tokens) {
+ std::string cc_token = MakeString(token);
+ const char *cToken = cc_token.c_str();
+ int id;
+ tokenizer->LookupId(cToken, &id);
+ [ret addObject:[NSNumber numberWithInt:id]];
+ }
+ return ret;
+}
diff --git a/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift b/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift
new file mode 100644
index 00000000..e805f301
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift
@@ -0,0 +1,50 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+import XCTest
+
+@testable import TFLBertTokenizer
+
+class TFLBertTokenizerTest: XCTestCase {
+ static let bundle = Bundle(for: TFLBertTokenizerTest.self)
+ static let mobileBertVocabPath = bundle.path(forResource: "vocab", ofType: "txt")!
+
+ func testInitBertTokenizerFromPath() {
+ let bertTokenizer = TFLBertTokenizer(vocabPath: TFLBertTokenizerTest.mobileBertVocabPath)
+
+ XCTAssertNotNil(bertTokenizer)
+
+ let tokens = bertTokenizer.tokens(fromInput: "i'm questionansweraskask")
+
+ XCTAssertEqual(tokens, ["i", "'", "m", "question", "##ans", "##wer", "##ask", "##ask"])
+
+ let ids = bertTokenizer.ids(fromTokens: tokens)
+
+ XCTAssertEqual(ids, [1045, 1005, 1049, 3160, 6962, 13777, 19895, 19895])
+ }
+
+ func testInitBertTokenizerFromVocab() {
+ let bertTokenizer = TFLBertTokenizer(vocab: ["hell", "##o", "wor", "##ld", "there"])
+
+ XCTAssertNotNil(bertTokenizer)
+
+ let tokens = bertTokenizer.tokens(fromInput: "hello there hello world")
+
+ XCTAssertEqual(tokens, ["hell", "##o", "there", "hell", "##o", "wor", "##ld"])
+
+ let ids = bertTokenizer.ids(fromTokens: tokens)
+
+ XCTAssertEqual(ids, [0, 1, 4, 0, 1, 2, 3])
+ }
+}
diff --git a/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLSentencepieceTokenizerTest.swift b/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLSentencepieceTokenizerTest.swift
new file mode 100644
index 00000000..c7c6d1e2
--- /dev/null
+++ b/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLSentencepieceTokenizerTest.swift
@@ -0,0 +1,37 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+import XCTest
+
+@testable import TFLSentencepieceTokenizer
+
+class TFLSentencepieceTokenizerTest: XCTestCase {
+ static let bundle = Bundle(for: TFLSentencepieceTokenizerTest.self)
+ static let spModelPath = bundle.path(forResource: "30k-clean", ofType: "model")!
+
+  func testInitSentencepieceTokenizerFromPath() {
+ let spTokenizer = TFLSentencepieceTokenizer(
+ modelPath: TFLSentencepieceTokenizerTest.spModelPath)
+
+ XCTAssertNotNil(spTokenizer)
+
+ let tokens = spTokenizer.tokens(fromInput: "good morning, i'm your teacher.\n")
+
+ XCTAssertEqual(tokens, ["▁good", "▁morning", ",", "▁i", "'", "m", "▁your", "▁teacher", "."])
+
+ let ids = spTokenizer.ids(fromTokens: tokens)
+
+ XCTAssertEqual(ids, [254, 959, 15, 31, 22, 79, 154, 2197, 9])
+ }
+}
diff --git a/tensorflow_lite_support/ios/utils/BUILD b/tensorflow_lite_support/ios/utils/BUILD
new file mode 100644
index 00000000..63f10915
--- /dev/null
+++ b/tensorflow_lite_support/ios/utils/BUILD
@@ -0,0 +1,15 @@
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+objc_library(
+ name = "TFLStringUtil",
+ srcs = [
+ "Sources/TFLStringUtil.mm",
+ ],
+ hdrs = [
+ "Sources/TFLStringUtil.h",
+ ],
+ module_name = "TFLStringUtil",
+)
diff --git a/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h b/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h
new file mode 100644
index 00000000..3ea091f5
--- /dev/null
+++ b/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h
@@ -0,0 +1,23 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import <Foundation/Foundation.h>
+
+#include <string>
+
+// Translates an NSString encoded in UTF-8 to a std::string.
+std::string MakeString(NSString*);
+
+// Translates a std::string to the equivalent NSString by making a copy.
+NSString* MakeNSString(const std::string&);
diff --git a/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
new file mode 100644
index 00000000..6e9cf238
--- /dev/null
+++ b/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm
@@ -0,0 +1,23 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h"
+
+std::string MakeString(NSString* str) { return std::string([str UTF8String]); }
+
+NSString* MakeNSString(const std::string& str) {
+ return [[NSString alloc] initWithBytes:const_cast<void*>(static_cast<const void*>(str.data()))
+ length:str.length()
+ encoding:NSUTF8StringEncoding];
+}
diff --git a/tensorflow_lite_support/java/AndroidManifest.xml b/tensorflow_lite_support/java/AndroidManifest.xml
new file mode 100644
index 00000000..14909296
--- /dev/null
+++ b/tensorflow_lite_support/java/AndroidManifest.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="org.tensorflow.lite.support">
+ <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/>
+</manifest>
diff --git a/tensorflow_lite_support/java/BUILD b/tensorflow_lite_support/java/BUILD
new file mode 100644
index 00000000..090b1add
--- /dev/null
+++ b/tensorflow_lite_support/java/BUILD
@@ -0,0 +1,78 @@
+# Description:
+# TensorFlow Lite Support API in Java.
+
+load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "AndroidManifest.xml",
+ "default_version_script.lds",
+ "debug_version_script.lds",
+])
+
+# Android Library target for TFLite Support Library. It depends on TensorFlow
+# Lite runtime (tensorflow/lite/java:tensorflowlite). If you don't want to
+# introduce the native library into dependencies, use
+# "tensorflowlite_support_java" instead, which depends on
+# tensorflow/lite/java:tensorflowlite_java.
+android_library(
+ name = "tensorflowlite_support",
+ srcs = glob(
+ ["src/java/org/tensorflow/lite/support/**/*.java"],
+ ),
+ javacopts = JAVACOPTS,
+ manifest = "AndroidManifest.xml",
+ deps = [
+ "@org_checkerframework_qual",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite",
+ ],
+)
+
+android_library(
+ name = "tensorflowlite_support_java",
+ srcs = glob(
+ ["src/java/org/tensorflow/lite/support/**/*.java"],
+ ),
+ javacopts = JAVACOPTS,
+ manifest = "AndroidManifest.xml",
+ deps = [
+ "@org_checkerframework_qual",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
+ ],
+)
+
+# TODO(b/156482505): Remove this target.
+alias(
+ name = "tensorflow-lite-support-nogpu",
+ actual = ":tensorflow-lite-support",
+)
+
+# This alias matches the associated .aar library name output style.
+alias(
+ name = "tensorflow-lite-support",
+ actual = ":tensorflowlite_support",
+)
+
+java_library(
+ name = "tensorflowlite_support_precondition_lib",
+ srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"],
+ javacopts = JAVACOPTS,
+ deps = [
+ "@org_checkerframework_qual",
+ ],
+)
+
+android_library(
+ name = "tensorflowlite_support_precondition",
+ srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"],
+ javacopts = JAVACOPTS,
+ manifest = "AndroidManifest.xml",
+ deps = [
+ "@org_checkerframework_qual",
+ ],
+)
diff --git a/tensorflow_lite_support/java/README.md b/tensorflow_lite_support/java/README.md
new file mode 100644
index 00000000..8d37bf8b
--- /dev/null
+++ b/tensorflow_lite_support/java/README.md
@@ -0,0 +1,38 @@
+# TensorFlow Lite Support
+
+TensorFlow Lite Support contains a set of tools and libraries that help with
+developing ML with TFLite for mobile apps. See the [documentation on
+tensorflow.org](https://www.tensorflow.org/lite/inference_with_metadata/overview)
+for more information about all the efforts under TensorFlow Lite Support.
+
+This directory contains the Java code for the TensorFlow Lite Support Library
+and the TensorFlow Lite Task Library.
+
+## TensorFlow Lite Android Support Library
+
+Mobile application developers typically interact with typed objects such as
+bitmaps or primitives such as integers. However, the TensorFlow Lite Interpreter
+that runs the on-device machine learning model uses tensors in the form of
+ByteBuffer, which can be difficult to debug and manipulate. The TensorFlow Lite
+Android Support Library is designed to help process the input and output of
+TensorFlow Lite models, and make the TensorFlow Lite interpreter easier to use.
+
+We welcome feedback from the community as we develop this support library,
+especially around:
+
+* Use cases we should support, including data types and operations
+* Ease of use: do the APIs make sense to the community?
+
+See the [documentation](https://www.tensorflow.org/lite/inference_with_metadata/lite_support)
+for more instructions and examples.
+
+
+## TensorFlow Lite Android Task Library
+
+TensorFlow Lite Task Library provides optimized ready-to-use model interfaces
+for popular machine learning tasks, such as image classification and question
+answering. The model interfaces are specifically designed for each task to
+achieve the best performance and usability. Task Library works cross-platform
+and is supported on Java, C++, and Swift.
+
+See the [documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview)
+for more instructions and examples.
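
As a concrete illustration of the Support Library's role, a minimal Java sketch
(asset names are hypothetical, error handling is omitted, and an Android
Context is assumed to be in scope):

    import java.nio.MappedByteBuffer;
    import java.util.List;
    import org.tensorflow.lite.Interpreter;
    import org.tensorflow.lite.support.common.FileUtil;

    // Memory-map a bundled model and load its labels from the app's assets.
    MappedByteBuffer model = FileUtil.loadMappedFile(context, "model.tflite");
    List<String> labels = FileUtil.loadLabels(context, "labels.txt");
    try (Interpreter interpreter = new Interpreter(model)) {
      // Fill input buffers and call interpreter.run(...) as usual.
    }
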
diff --git a/tensorflow_lite_support/java/debug_version_script.lds b/tensorflow_lite_support/java/debug_version_script.lds
new file mode 100644
index 00000000..53553a42
--- /dev/null
+++ b/tensorflow_lite_support/java/debug_version_script.lds
@@ -0,0 +1,5 @@
+VERS_1.0 {
+ # Export everything for debug purpose.
+ global:
+ *;
+};
diff --git a/tensorflow_lite_support/java/default_version_script.lds b/tensorflow_lite_support/java/default_version_script.lds
new file mode 100644
index 00000000..46bbffe7
--- /dev/null
+++ b/tensorflow_lite_support/java/default_version_script.lds
@@ -0,0 +1,12 @@
+VERS_1.0 {
+ # Export JNI and native C symbols.
+ global:
+ Java_*;
+ JNI_OnLoad;
+ JNI_OnUnload;
+ TfLite*;
+
+ # Hide everything else.
+ local:
+ *;
+};
diff --git a/tensorflow_lite_support/java/jni/BUILD b/tensorflow_lite_support/java/jni/BUILD
new file mode 100644
index 00000000..2c01b50d
--- /dev/null
+++ b/tensorflow_lite_support/java/jni/BUILD
@@ -0,0 +1,48 @@
+package(default_visibility = ["//tensorflow_lite_support:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+# Helper target for exposing JNI headers across multiple platforms.
+cc_library(
+ name = "jni",
+ hdrs = select({
+ # The Android toolchain makes "jni.h" available in the include path.
+ # For non-Android toolchains, generate jni.h and jni_md.h.
+ "//tensorflow_lite_support:android": [],
+ "//conditions:default": [
+ ":jni.h",
+ ":jni_md.h",
+ ],
+ }),
+ includes = select({
+ "//tensorflow_lite_support:android": [],
+ "//conditions:default": ["."],
+ }),
+ visibility = ["//visibility:public"],
+)
+
+# Silly rules to make
+# #include <jni.h>
+# in the source headers work
+# (in combination with the "includes" attribute of the cc_library rule
+# above; not needed when using the Android toolchain).
+#
+# Inspired by:
+# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD
+# but hopefully there is a simpler alternative to this.
+genrule(
+ name = "copy_jni_h",
+ srcs = ["@bazel_tools//tools/jdk:jni_header"],
+ outs = ["jni.h"],
+ cmd = "cp -f $< $@",
+)
+
+genrule(
+ name = "copy_jni_md_h",
+ srcs = select({
+ "//tensorflow_lite_support:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
+ "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
+ }),
+ outs = ["jni_md.h"],
+ cmd = "cp -f $< $@",
+)
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
new file mode 100644
index 00000000..e83fd403
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java
@@ -0,0 +1,184 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common;
+
+import android.content.Context;
+import android.content.res.AssetFileDescriptor;
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.List;
+import org.checkerframework.checker.nullness.qual.NonNull;
+
+/** File I/O utilities. */
+public class FileUtil {
+ private FileUtil() {}
+
+  /**
+   * Loads labels from a label file into a list of strings.
+   *
+   * <p>A legal label file is a plain text file whose contents are split into lines, with each
+   * line being an individual value. The file should be stored in the assets of the context.
+   *
+   * @param context The context that holds the assets.
+   * @param filePath The path of the label file, relative to the assets directory.
+   * @return a list of labels.
+   * @throws IOException if an error occurs when opening or reading the file.
+   */
+ @NonNull
+ public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath)
+ throws IOException {
+ return loadLabels(context, filePath, Charset.defaultCharset());
+ }
+
+  /**
+   * Loads labels from a label file into a list of strings.
+   *
+   * <p>A legal label file is a plain text file whose contents are split into lines, with each
+   * line being an individual value. Empty lines are ignored. The file should be stored in the
+   * assets of the context.
+   *
+   * @param context The context that holds the assets.
+   * @param filePath The path of the label file, relative to the assets directory.
+   * @param cs {@code Charset} to use when decoding the content of the label file.
+   * @return a list of labels.
+   * @throws IOException if an error occurs when opening or reading the file.
+   */
+ @NonNull
+ public static List<String> loadLabels(
+ @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
+ SupportPreconditions.checkNotNull(context, "Context cannot be null.");
+ SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
+ try (InputStream inputStream = context.getAssets().open(filePath)) {
+ return loadLabels(inputStream, cs);
+ }
+ }
+
+ /**
+ * Loads labels from an input stream of an opened label file. See details for label files in
+ * {@link FileUtil#loadLabels(Context, String)}.
+ *
+ * @param inputStream the input stream of an opened label file.
+ * @return a list of labels.
+   * @throws IOException if an error occurs when opening or reading the file.
+ */
+ @NonNull
+ public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException {
+ return loadLabels(inputStream, Charset.defaultCharset());
+ }
+
+ /**
+ * Loads labels from an input stream of an opened label file. See details for label files in
+ * {@link FileUtil#loadLabels(Context, String)}.
+ *
+ * @param inputStream the input stream of an opened label file.
+   * @param cs {@code Charset} to use when decoding the content of the label file.
+   * @return a list of labels.
+   * @throws IOException if an error occurs when opening or reading the file.
+ */
+ @NonNull
+ public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs)
+ throws IOException {
+ List<String> labels = new ArrayList<>();
+ try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ if (line.trim().length() > 0) {
+ labels.add(line);
+ }
+ }
+ return labels;
+ }
+ }
+
+  /**
+   * Loads a vocabulary file (a single-column text file) into a list of strings.
+   *
+   * <p>A vocabulary file is a single-column plain text file whose contents are split into lines,
+   * with each line being an individual value. The file should be stored in the assets of the
+   * context.
+   *
+   * @param context The context that holds the assets.
+   * @param filePath The path of the vocabulary file, relative to the assets directory.
+   * @param cs {@code Charset} to use when decoding the content of the vocabulary file.
+   * @return a list of vocabulary words.
+   * @throws IOException if an error occurs when opening or reading the file.
+   */
+ @NonNull
+ public static List<String> loadSingleColumnTextFile(
+ @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException {
+ return loadLabels(context, filePath, cs);
+ }
+
+  /**
+   * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column
+   * text file). See details for vocabulary files in {@link
+   * FileUtil#loadSingleColumnTextFile(Context, String, Charset)}.
+   *
+   * @param inputStream the input stream of an opened vocabulary file.
+   * @param cs {@code Charset} to use when decoding the content of the vocabulary file.
+   * @return a list of vocabulary words.
+   * @throws IOException if an error occurs when opening or reading the file.
+   */
+ @NonNull
+ public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs)
+ throws IOException {
+ return loadLabels(inputStream, cs);
+ }
+
+ /**
+ * Loads a file from the asset folder through memory mapping.
+ *
+ * @param context Application context to access assets.
+ * @param filePath Asset path of the file.
+ * @return the loaded memory mapped file.
+   * @throws IOException if an I/O error occurs when loading the file.
+ */
+ @NonNull
+ public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath)
+ throws IOException {
+ SupportPreconditions.checkNotNull(context, "Context should not be null.");
+ SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
+ try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
+ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ }
+ }
+
+ /**
+ * Loads a binary file from the asset folder.
+ *
+ * @param context Application context to access assets.
+ * @param filePath Asset path of the file.
+ * @return the byte array for the binary file.
+   * @throws IOException if an I/O error occurs when loading the file.
+ */
+ @NonNull
+ public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath)
+ throws IOException {
+ ByteBuffer buffer = loadMappedFile(context, filePath);
+ byte[] byteArray = new byte[buffer.remaining()];
+ buffer.get(byteArray);
+ return byteArray;
+ }
+}
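
A usage sketch of the loaders above, assuming an Android Context is in scope
and hypothetical asset names:

    // Labels decoded with the default charset.
    List<String> labels = FileUtil.loadLabels(context, "labels.txt");
    // Labels in a specific encoding.
    List<String> utf16Labels =
        FileUtil.loadLabels(context, "labels_utf16.txt", StandardCharsets.UTF_16);
    // Raw bytes of a small binary asset.
    byte[] embedding = FileUtil.loadByteFromFile(context, "embedding.bin");
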
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
new file mode 100644
index 00000000..38dfe881
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java
@@ -0,0 +1,31 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common;
+
+/**
+ * The common interface for classes that carry an "apply" method, which converts an object of
+ * type T into another object of type T.
+ *
+ * @param <T> The class that the Operator handles.
+ */
+public interface Operator<T> {
+
+ /**
+ * Applies an operation on a T object, returning a T object.
+ *
+   * <p>Note: The returned object may be the same instance as the given input, and the given
+   * input may be modified.
+ */
+ T apply(T x);
+}
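
A tiny sketch of an Operator implementation (the class name is hypothetical):

    // Trims surrounding whitespace; returns a new String, leaving the input unchanged.
    public class TrimOp implements Operator<String> {
      @Override
      public String apply(String x) {
        return x.trim();
      }
    }
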
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
new file mode 100644
index 00000000..07d7e2bd
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java
@@ -0,0 +1,23 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common;
+
+/**
+ * Processes a T object with prepared {@link Operator}s.
+ */
+public interface Processor<T> {
+ T process(T input);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
new file mode 100644
index 00000000..ff0c6406
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java
@@ -0,0 +1,82 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.checkerframework.checker.nullness.qual.NonNull;
+
+/**
+ * A processor base class that chains a series of {@link Operator}s and executes them in order.
+ *
+ * <p>Typically, users should use one of its subclasses, e.g. {@link
+ * org.tensorflow.lite.support.image.ImageProcessor}, rather than using this class directly.
+ *
+ * @param <T> The type that the Operator is handling.
+ */
+public class SequentialProcessor<T> implements Processor<T> {
+
+ /** List of operators added to this {@link SequentialProcessor}. */
+ protected final List<Operator<T>> operatorList;
+ /**
+ * The {@link Map} between the operator name and the corresponding op indexes in {@code
+   * operatorList}. An operator may be added multiple times to this {@link SequentialProcessor}.
+ */
+ protected final Map<String, List<Integer>> operatorIndex;
+
+ protected SequentialProcessor(Builder<T> builder) {
+ operatorList = builder.operatorList;
+ operatorIndex = Collections.unmodifiableMap(builder.operatorIndex);
+ }
+
+ @Override
+ public T process(T x) {
+ for (Operator<T> op : operatorList) {
+ x = op.apply(x);
+ }
+ return x;
+ }
+
+  /** The inner builder class for building a {@link SequentialProcessor}. */
+ protected static class Builder<T> {
+
+ private final List<Operator<T>> operatorList;
+ private final Map<String, List<Integer>> operatorIndex;
+
+ protected Builder() {
+ operatorList = new ArrayList<>();
+ operatorIndex = new HashMap<>();
+ }
+
+ public Builder<T> add(@NonNull Operator<T> op) {
+ SupportPreconditions.checkNotNull(op, "Adding null Op is illegal.");
+ operatorList.add(op);
+ String operatorName = op.getClass().getName();
+ if (!operatorIndex.containsKey(operatorName)) {
+ operatorIndex.put(operatorName, new ArrayList<Integer>());
+ }
+ operatorIndex.get(operatorName).add(operatorList.size() - 1);
+ return this;
+ }
+
+ public SequentialProcessor<T> build() {
+ return new SequentialProcessor<T>(this);
+ }
+ }
+}
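
A sketch of how a subclass typically wires up the protected builder (names are
hypothetical; ImageProcessor in the image package follows the same pattern):

    public class StringPipeline extends SequentialProcessor<String> {
      private StringPipeline(Builder builder) {
        super(builder);
      }

      public static class Builder extends SequentialProcessor.Builder<String> {
        @Override
        public StringPipeline build() {
          return new StringPipeline(this);
        }
      }
    }

    // Operators run in the order they were added, e.g. with the TrimOp sketch above:
    String out = new StringPipeline.Builder().add(new TrimOp()).build().process("  hi ");
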
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java
new file mode 100644
index 00000000..8620e13e
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java
@@ -0,0 +1,184 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common;
+
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/** Static error-checking utility methods. */
+public final class SupportPreconditions {
+ /**
+ * Ensures that an object reference passed as a parameter to the calling method is not null.
+ *
+ * @param reference an object reference
+ * @return the non-null reference that was validated
+ * @throws NullPointerException if {@code reference} is null
+ */
+ public static <T extends Object> T checkNotNull(T reference) {
+ if (reference == null) {
+ throw new NullPointerException("The object reference is null.");
+ }
+ return reference;
+ }
+
+ /**
+ * Ensures that an object reference passed as a parameter to the calling method is not null.
+ *
+ * @param reference an object reference
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}
+ * @return the non-null reference that was validated
+ * @throws NullPointerException if {@code reference} is null
+ */
+ public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) {
+ if (reference == null) {
+ throw new NullPointerException(String.valueOf(errorMessage));
+ }
+ return reference;
+ }
+
+ /**
+ * Ensures that the given String is not empty and not null.
+ *
+ * @param string the String to test
+ * @return the non-null non-empty String that was validated
+ * @throws IllegalArgumentException if {@code string} is null or empty
+ */
+ public static String checkNotEmpty(String string) {
+ if (string == null || string.length() == 0) {
+ throw new IllegalArgumentException("Given String is empty or null.");
+ }
+ return string;
+ }
+
+ /**
+ * Ensures that the given String is not empty and not null.
+ *
+ * @param string the String to test
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}
+ * @return the non-null non-empty String that was validated
+ * @throws IllegalArgumentException if {@code string} is null or empty
+ */
+ public static String checkNotEmpty(String string, Object errorMessage) {
+ if (string == null || string.length() == 0) {
+ throw new IllegalArgumentException(String.valueOf(errorMessage));
+ }
+ return string;
+ }
+
+ /**
+ * Ensures the truth of an expression involving one or more parameters to the calling method.
+ *
+ * @param expression a boolean expression.
+ * @throws IllegalArgumentException if {@code expression} is false.
+ */
+ public static void checkArgument(boolean expression) {
+ if (!expression) {
+ throw new IllegalArgumentException();
+ }
+ }
+
+ /**
+ * Ensures the truth of an expression involving one or more parameters to the calling method.
+ *
+ * @param expression a boolean expression.
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}.
+ * @throws IllegalArgumentException if {@code expression} is false.
+ */
+ public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
+ if (!expression) {
+ throw new IllegalArgumentException(String.valueOf(errorMessage));
+ }
+ }
+
+ /**
+ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
+ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
+ *
+ * @param index a user-supplied index identifying an element of an array, list or string
+ * @param size the size of that array, list or string
+ * @return the value of {@code index}
+ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
+ * @throws IllegalArgumentException if {@code size} is negative
+ */
+ public static int checkElementIndex(int index, int size) {
+ return checkElementIndex(index, size, "index");
+ }
+
+ /**
+ * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size
+ * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive.
+ *
+ * @param index a user-supplied index identifying an element of an array, list or string
+ * @param size the size of that array, list or string
+ * @param desc the text to use to describe this index in an error message
+ * @return the value of {@code index}
+ * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size}
+ * @throws IllegalArgumentException if {@code size} is negative
+ */
+ public static int checkElementIndex(int index, int size, @Nullable String desc) {
+    // Carefully optimized for execution by HotSpot.
+ if (index < 0 || index >= size) {
+ throw new IndexOutOfBoundsException(badElementIndex(index, size, desc));
+ }
+ return index;
+ }
+
+ /**
+ * Ensures the truth of an expression involving the state of the calling instance, but not
+ * involving any parameters to the calling method.
+ *
+ * @param expression a boolean expression
+ * @throws IllegalStateException if {@code expression} is false
+ */
+ public static void checkState(boolean expression) {
+ if (!expression) {
+ throw new IllegalStateException();
+ }
+ }
+
+ /**
+ * Ensures the truth of an expression involving the state of the calling instance, but not
+ * involving any parameters to the calling method.
+ *
+ * @param expression a boolean expression
+ * @param errorMessage the exception message to use if the check fails; will be converted to a
+ * string using {@link String#valueOf(Object)}
+ * @throws IllegalStateException if {@code expression} is false
+ * @see Verify#verify Verify.verify()
+ */
+ public static void checkState(boolean expression, @Nullable Object errorMessage) {
+ if (!expression) {
+ throw new IllegalStateException(String.valueOf(errorMessage));
+ }
+ }
+
+ private static String badElementIndex(int index, int size, @Nullable String desc) {
+ if (index < 0) {
+ return String.format("%s (%s) must not be negative", desc, index);
+ } else if (size < 0) {
+ throw new IllegalArgumentException("negative size: " + size);
+ } else { // index >= size
+ return String.format("%s (%s) must be less than size (%s)", desc, index, size);
+ }
+ }
+
+ private SupportPreconditions() {
+ throw new AssertionError("SupportPreconditions is uninstantiable.");
+ }
+}
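For orientation while reading the rest of this import, here is a minimal sketch of how these checks are typically used by callers. The class and member names below are hypothetical, for illustration only:

    import org.tensorflow.lite.support.common.SupportPreconditions;

    final class PixelReader { // hypothetical caller, not part of this change
      private final float[] pixels;

      PixelReader(float[] pixels) {
        // Fails fast with IllegalArgumentException and the given message.
        SupportPreconditions.checkArgument(pixels.length > 0, "pixels must not be empty.");
        this.pixels = pixels;
      }

      float get(int index) {
        // Fails fast with a descriptive IndexOutOfBoundsException, e.g.
        // "pixel index (5) must be less than size (3)".
        SupportPreconditions.checkElementIndex(index, pixels.length, "pixel index");
        return pixels[index];
      }
    }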
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
new file mode 100644
index 00000000..d1b7021d
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java
@@ -0,0 +1,27 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common;
+
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * Applies some operation on TensorBuffers.
+ */
+public interface TensorOperator extends Operator<TensorBuffer> {
+ /** @see Operator#apply(Object) . */
+ @Override
+ TensorBuffer apply(TensorBuffer input);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
new file mode 100644
index 00000000..31531b2e
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java
@@ -0,0 +1,68 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common;
+
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * TensorProcessor is a helper class for preprocessing and postprocessing tensors. It could
+ * transform a {@link TensorBuffer} to another by executing a chain of {@link TensorOperator}.
+ *
+ * <p>Example Usage:
+ *
+ * <pre>
+ * TensorProcessor processor = new TensorProcessor.Builder().add(new NormalizeOp(1, 2)).build();
+ * TensorBuffer anotherTensorBuffer = processor.process(tensorBuffer);
+ * </pre>
+ *
+ * @see TensorProcessor.Builder to build a {@link TensorProcessor} instance.
+ * @see TensorProcessor#process(TensorBuffer) to apply the processor on a {@link TensorBuffer}.
+ */
+public class TensorProcessor extends SequentialProcessor<TensorBuffer> {
+ private TensorProcessor(Builder builder) {
+ super(builder);
+ }
+
+ /** The Builder to create a {@link TensorProcessor}, which could be executed later. */
+ public static class Builder extends SequentialProcessor.Builder<TensorBuffer> {
+
+ /**
+ * Creates a Builder to build {@link TensorProcessor}.
+ *
+ * @see #add(TensorOperator) to add an Op.
+ * @see #build() to complete the building process and get a built Processor.
+ */
+ public Builder() {
+ super();
+ }
+
+ /**
+ * Adds a {@link TensorOperator} to the operator chain.
+ *
+ * @param op the operator to be appended to the chain.
+ */
+ public TensorProcessor.Builder add(TensorOperator op) {
+ super.add(op);
+ return this;
+ }
+
+ /** Completes the building process and gets the {@link TensorProcessor} instance. */
+ @Override
+ public TensorProcessor build() {
+ return new TensorProcessor(this);
+ }
+ }
+}
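A slightly fuller sketch of the builder pattern shown in the class Javadoc above, chaining two ops. The normalization constants and wrapper class are hypothetical:

    import org.tensorflow.lite.DataType;
    import org.tensorflow.lite.support.common.TensorProcessor;
    import org.tensorflow.lite.support.common.ops.CastOp;
    import org.tensorflow.lite.support.common.ops.NormalizeOp;
    import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

    final class PreprocessExample {
      static TensorBuffer preprocess(TensorBuffer input) {
        // Normalize with hypothetical constants, then cast for a float model input.
        TensorProcessor processor =
            new TensorProcessor.Builder()
                .add(new NormalizeOp(127.5f, 127.5f)) // hypothetical constants
                .add(new CastOp(DataType.FLOAT32))
                .build();
        return processor.process(input);
      }
    }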
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
new file mode 100644
index 00000000..3355b185
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java
@@ -0,0 +1,55 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common.ops;
+
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/** Casts a {@link TensorBuffer} to a specified data type. */
+public class CastOp implements TensorOperator {
+
+ private final DataType destinationType;
+
+ /**
+ * Constructs a CastOp.
+ *
+ * <p>Note: To convert the type of a certain {@link TensorBuffer} on the fly rather than in a
+ * processor, use {@link TensorBuffer#createFrom(TensorBuffer, DataType)} directly.
+ *
+ * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code
+ * destinationType}, the original buffer will be directly returned.
+ *
+ * @param destinationType the data type of the cast {@link TensorBuffer}.
+ * @throws IllegalArgumentException if {@code destinationType} is neither {@link DataType#UINT8}
+ * nor {@link DataType#FLOAT32}.
+ */
+ public CastOp(DataType destinationType) {
+ SupportPreconditions.checkArgument(
+ destinationType == DataType.UINT8 || destinationType == DataType.FLOAT32,
+ "Destination type " + destinationType + " is not supported.");
+ this.destinationType = destinationType;
+ }
+
+ @Override
+ public TensorBuffer apply(TensorBuffer input) {
+ if (input.getDataType() == destinationType) {
+ return input;
+ }
+ return TensorBuffer.createFrom(input, destinationType);
+ }
+}
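A minimal sketch of the pass-through behavior described in the Javadoc above; the wrapper class is hypothetical:

    import org.tensorflow.lite.DataType;
    import org.tensorflow.lite.support.common.ops.CastOp;
    import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

    final class CastExample {
      static void demo() {
        TensorBuffer floatBuffer = TensorBuffer.createDynamic(DataType.FLOAT32);
        CastOp toFloat = new CastOp(DataType.FLOAT32);
        // Input is already FLOAT32, so the very same instance comes back uncopied.
        boolean sameInstance = (toFloat.apply(floatBuffer) == floatBuffer); // true
        // new CastOp(DataType.INT32) would throw IllegalArgumentException at construction.
      }
    }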
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
new file mode 100644
index 00000000..18817478
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/DequantizeOp.java
@@ -0,0 +1,40 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common.ops;
+
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * Dequantizes a {@link TensorBuffer} with given {@code zeroPoint} and {@code scale}.
+ *
+ * <p>Note: The data type of the output tensor is always {@code FLOAT32}, except when the
+ * DequantizeOp is effectively an identity op, e.g. when {@code zeroPoint} is 0 and {@code scale}
+ * is 1 (in this case, the output tensor is the same instance as the input).
+ *
+ * <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link DequantizeOp} will be bypassed,
+ * which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful
+ * when passing in the quantization parameters that are extracted directly from the TFLite model
+ * flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read
+ * as 0.
+ */
+public class DequantizeOp extends NormalizeOp implements TensorOperator {
+
+ public DequantizeOp(float zeroPoint, float scale) {
+ // Dequantization: f = (q - z) * s, i.e. NormalizeOp with mean = z and stddev = 1 / s.
+ super(zeroPoint, 1 / scale);
+ }
+}
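A worked example of the mapping above, with hypothetical quantization parameters zeroPoint = 128 and scale = 0.5; the wrapper class is illustrative:

    import org.tensorflow.lite.support.common.ops.DequantizeOp;
    import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

    final class DequantizeExample {
      static TensorBuffer dequantize(TensorBuffer quantized) {
        // f = (q - z) * s; with z = 128 and s = 0.5, a stored value of 130 maps to
        // (130 - 128) * 0.5 = 1.0f. Internally this is NormalizeOp(mean = 128, stddev = 1/0.5).
        DequantizeOp op = new DequantizeOp(128f, 0.5f); // hypothetical parameters
        return op.apply(quantized);
      }
    }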
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
new file mode 100644
index 00000000..9db1388b
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java
@@ -0,0 +1,160 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common.ops;
+
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat;
+
+/**
+ * Normalizes a {@link TensorBuffer} with given mean and stddev: output = (input - mean) / stddev.
+ */
+public class NormalizeOp implements TensorOperator {
+
+ // mean.length should always be equal to stddev.length and always >= 1.
+ private final float[] mean;
+ private final float[] stddev;
+ private final int numChannels;
+ private final boolean isIdentityOp;
+
+ /**
+ * Initializes a NormalizeOp. When applied, it creates a new {@link TensorBuffer}, which
+ * satisfies:
+ *
+ * <pre>
+ * output = (input - mean) / stddev
+ * </pre>
+ *
+ * <p>In the following two cases, this op resets {@code stddev} to 1 to bypass the
+ * normalization: <br>
+ * 1. Both {@code mean} and {@code stddev} are 0. <br>
+ * 2. {@code mean} is 0 and {@code stddev} is Infinity.
+ *
+ * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will
+ * happen, and the original input will be returned directly in execution.
+ *
+ * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
+ * present, except when the input is a {@link DataType#UINT8} tensor with {@code mean} set to 0
+ * and {@code stddev} set to 1, in which case the original {@link DataType#UINT8} tensor is
+ * returned.
+ *
+ * @param mean the mean value to be subtracted first.
+ * @param stddev the standard deviation to divide by afterwards.
+ * @throws IllegalArgumentException if {@code stddev} is zero.
+ */
+ public NormalizeOp(float mean, float stddev) {
+ // Make exceptions to the cases that
+ // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters
+ // from a tensor which does not have the values populated in the metadata. The same situation
+ // may also happen to the quantization parameters.
+ // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization
+ // parameters from a tensor which does not have the values populated in the metadata, and then
+ // passing the parameters into the DequantizeOp.
+ // Bypass both cases by resetting stddev to 1.0f.
+ if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) {
+ stddev = 1.0f;
+ }
+
+ SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero.");
+
+ // The op is an identity op when it computes (x - 0) / 1.
+ this.isIdentityOp = (mean == 0.0f && stddev == 1.0f);
+ this.mean = new float[] {mean};
+ this.stddev = new float[] {stddev};
+ this.numChannels = 1;
+ }
+
+ /**
+ * Initializes a NormalizeOp. When applied, it creates a new {@link TensorBuffer}, which
+ * satisfies:
+ *
+ * <pre>
+ * // Pseudo code. [...][i] means a certain element whose channel id is i.
+ * output[...][i] = (input[...][i] - mean[i]) / stddev[i]
+ * </pre>
+ *
+ * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no
+ * computation will happen, and the original input will be returned directly in execution.
+ *
+ * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at
+ * present, except when the input is a {@link DataType#UINT8} tensor, all {@code mean} values are
+ * 0, and all {@code stddev} values are 1.
+ *
+ * @param mean the per-channel mean values to be subtracted first.
+ * @param stddev the per-channel standard deviations to divide by afterwards.
+ * @throws IllegalArgumentException if any {@code stddev} is zero, if {@code mean} has a different
+ * number of elements than {@code stddev}, or if either of them is empty.
+ */
+ public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) {
+ SupportPreconditions.checkNotNull(mean, "Mean cannot be null");
+ SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null");
+ SupportPreconditions.checkArgument(
+ mean.length == stddev.length,
+ "Per channel normalization requires same number of means and stddevs");
+ SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty.");
+ this.mean = mean.clone();
+ this.stddev = stddev.clone();
+ boolean allMeansAreZeroAndAllDevsAre1 = true;
+ this.numChannels = mean.length;
+ for (int i = 0; i < numChannels; i++) {
+ SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero.");
+ if (this.stddev[i] != 1 || this.mean[i] != 0) {
+ allMeansAreZeroAndAllDevsAre1 = false;
+ }
+ }
+ this.isIdentityOp = allMeansAreZeroAndAllDevsAre1;
+ }
+
+ /**
+ * Applies the defined normalization on given tensor and returns the result.
+ *
+ * <p>Note: {@code input} is possibly the same instance with the output.
+ *
+ * @param input input tensor. It may be the same instance with the output.
+ * @return output tensor.
+ */
+ @Override
+ @NonNull
+ public TensorBuffer apply(@NonNull TensorBuffer input) {
+ if (isIdentityOp) {
+ return input;
+ }
+ int[] shape = input.getShape();
+ SupportPreconditions.checkArgument(
+ numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels),
+ "Number of means (stddevs) is not same with number of channels (size of last axis).");
+ // TODO(136750944): Eliminate the array copy here.
+ float[] values = input.getFloatArray();
+ int j = 0;
+ for (int i = 0; i < values.length; i++) {
+ values[i] = (values[i] - mean[j]) / stddev[j];
+ j = (j + 1) % numChannels;
+ }
+ TensorBuffer output;
+ if (input.isDynamic()) {
+ output = TensorBufferFloat.createDynamic(DataType.FLOAT32);
+ } else {
+ output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32);
+ }
+ output.loadArray(values, shape);
+ return output;
+ }
+}
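A sketch of the per-channel constructor on an RGB image tensor. The wrapper class and the mean/stddev values are hypothetical, chosen only to illustrate the shape contract:

    import org.tensorflow.lite.support.common.ops.NormalizeOp;
    import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

    final class PerChannelNormalizeExample {
      static TensorBuffer normalizeRgb(TensorBuffer rgbBuffer) {
        // rgbBuffer is assumed to have a last axis of size 3, e.g. shape {1, h, w, 3}.
        float[] means = {127.0f, 128.0f, 129.0f};   // hypothetical per-channel means
        float[] stddevs = {127.5f, 127.5f, 127.5f}; // hypothetical per-channel stddevs
        return new NormalizeOp(means, stddevs).apply(rgbBuffer);
      }
    }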
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
new file mode 100644
index 00000000..8b3e82ae
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java
@@ -0,0 +1,41 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.common.ops;
+
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * Quantizes a {@link TensorBuffer} with given {@code zeroPoint} and {@code scale}.
+ *
+ * <p>Note: {@link QuantizeOp} does not cast output to UINT8, but only performs the quantization
+ * math on top of input. The data type of output tensor is always {@code FLOAT32} except that the Op
+ * is effectively an identity Op (in this case, the output tensor is the same instance as the
+ * input). To connect with quantized model, a {@link CastOp} is probably needed.
+ *
+ * <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link QuantizeOp} will be bypassed,
+ * which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful
+ * when passing in the quantization parameters that are extracted directly from the TFLite model
+ * flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read
+ * as 0.
+ */
+public class QuantizeOp extends NormalizeOp implements TensorOperator {
+
+ public QuantizeOp(float zeroPoint, float scale) {
+ // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s
+ super(-zeroPoint * scale, scale);
+ }
+}
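A worked check of the constructor math above, again with hypothetical zeroPoint = 128 and scale = 0.5; the wrapper class is illustrative:

    import org.tensorflow.lite.support.common.ops.QuantizeOp;
    import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

    final class QuantizeExample {
      static TensorBuffer quantize(TensorBuffer real) {
        // q = f / s + z = (f - (-z * s)) / s, i.e. NormalizeOp(mean = -z * s, stddev = s).
        // With z = 128 and s = 0.5: mean = -64, so f = 1.0 maps to (1.0 + 64) / 0.5 = 130,
        // recovering the q from the DequantizeOp example. The output stays FLOAT32;
        // chain a CastOp(DataType.UINT8) afterwards to feed a quantized model.
        QuantizeOp op = new QuantizeOp(128f, 0.5f); // hypothetical parameters
        return op.apply(real);
      }
    }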
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
new file mode 100644
index 00000000..b9590bfd
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java
@@ -0,0 +1,80 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image;
+
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
+
+import android.graphics.Bitmap;
+import android.graphics.Bitmap.Config;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/** Holds a {@link Bitmap} and converts it to other image formats as needed. */
+final class BitmapContainer implements ImageContainer {
+
+ private final Bitmap bitmap;
+
+ /**
+ * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}.
+ *
+ * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888
+ */
+ static BitmapContainer create(Bitmap bitmap) {
+ return new BitmapContainer(bitmap);
+ }
+
+ private BitmapContainer(Bitmap bitmap) {
+ checkNotNull(bitmap, "Cannot load null bitmap.");
+ checkArgument(
+ bitmap.getConfig().equals(Config.ARGB_8888), "Only supports loading ARGB_8888 bitmaps.");
+ this.bitmap = bitmap;
+ }
+
+ @Override
+ public BitmapContainer clone() {
+ return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable()));
+ }
+
+ @Override
+ public Bitmap getBitmap() {
+ // Not making a defensive copy for performance considerations. During image processing,
+ // users may need to set and get the bitmap many times.
+ return bitmap;
+ }
+
+ @Override
+ public TensorBuffer getTensorBuffer(DataType dataType) {
+ TensorBuffer buffer = TensorBuffer.createDynamic(dataType);
+ ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer);
+ return buffer;
+ }
+
+ @Override
+ public int getWidth() {
+ return bitmap.getWidth();
+ }
+
+ @Override
+ public int getHeight() {
+ return bitmap.getHeight();
+ }
+
+ @Override
+ public ColorSpaceType getColorSpaceType() {
+ return ColorSpaceType.fromBitmapConfig(bitmap.getConfig());
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
new file mode 100644
index 00000000..1a463305
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java
@@ -0,0 +1,244 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image;
+
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
+
+import android.graphics.RectF;
+import java.nio.ByteBuffer;
+import java.nio.FloatBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * Helper class for converting values that represent bounding boxes into rectangles.
+ *
+ * <p>The class provides a static function to create bounding boxes as {@link RectF} from different
+ * types of configurations.
+ *
+ * <p>Generally, a bounding box could be represented by 4 float values, but the values could be
+ * interpreted in many ways. We now support 3 {@link Type} of configurations, and the order of
+ * elements in each type is configurable as well.
+ */
+public final class BoundingBoxUtil {
+
+ /** Denotes how a bounding box is represented. */
+ public enum Type {
+ /**
+ * Represents the bounding box by using the combination of boundaries, {left, top, right,
+ * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated by an
+ * index array.
+ */
+ BOUNDARIES,
+ /**
+ * Represents the bounding box by using the upper_left corner, width and height. The default
+ * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an
+ * index array.
+ */
+ UPPER_LEFT,
+ /**
+ * Represents the bounding box by using the center of the box, width and height. The default
+ * order is {center_x, center_y, width, height}. Other orders can be indicated by an index
+ * array.
+ */
+ CENTER,
+ }
+
+ /** Denotes if the coordinates are actual pixels or relative ratios. */
+ public enum CoordinateType {
+ /** The coordinates are relative ratios in range [0, 1]. */
+ RATIO,
+ /** The coordinates are actual pixel values. */
+ PIXEL
+ }
+
+ /**
+ * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes.
+ *
+ * @param tensor holds the data representing some boxes.
+ * @param valueIndex denotes the order of the elements defined in each bounding box type. The
+ * index array must contain exactly 4 elements. For example, to denote the default order of
+ * BOUNDARIES, {left, top, right, bottom}, the index array should be {0, 1, 2, 3}. To denote
+ * the order {left, right, top, bottom}, it should be {0, 2, 1, 3}.
+ * <p>The index array can be applied to all bounding box types to adjust the order of their
+ * corresponding underlying elements.
+ * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The
+ * size of that dimension is required to be 4. Index here starts from 0. For example, if the
+ * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axes are also
+ * supported: -1 gives the last axis and -2 gives the second to last, etc. For shape 10x4, the
+ * axis is likely to be 1 (or -1, equivalently).
+ * @param type defines how values should be converted into boxes. See {@link Type}
+ * @param coordinateType defines how values are interpreted to coordinates. See {@link
+ * CoordinateType}
+ * @param height the height of the image which the boxes belong to. Only has effects when {@code
+ * coordinateType} is {@link CoordinateType#RATIO}
+ * @param width the width of the image which the boxes belong to. Only has effects when {@code
+ * coordinateType} is {@link CoordinateType#RATIO}
+ * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except
+ * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code
+ * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, the result will be a list
+ * of 20 bounding boxes.
+ * @throws IllegalArgumentException if size of bounding box dimension (set by {@code
+ * boundingBoxAxis}) is not 4.
+ * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)} where
+ * {@code D} is the number of dimensions of the {@code tensor}.
+ * @throws IllegalArgumentException if {@code tensor} has data type other than {@link
+ * DataType#FLOAT32}.
+ */
+ public static List<RectF> convert(
+ TensorBuffer tensor,
+ int[] valueIndex,
+ int boundingBoxAxis,
+ Type type,
+ CoordinateType coordinateType,
+ int height,
+ int width) {
+ int[] shape = tensor.getShape();
+ checkArgument(
+ boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length,
+ String.format(
+ "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input"
+ + " tensor (shape=%s)",
+ boundingBoxAxis, Arrays.toString(shape)));
+ if (boundingBoxAxis < 0) {
+ boundingBoxAxis = shape.length + boundingBoxAxis;
+ }
+ checkArgument(
+ shape[boundingBoxAxis] == 4,
+ String.format(
+ "Size of bounding box dimension %d is not 4. Got %d in shape %s",
+ boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape)));
+ checkArgument(
+ valueIndex.length == 4,
+ String.format(
+ "Bounding box index array length %d is not 4. Got index array %s",
+ valueIndex.length, Arrays.toString(valueIndex)));
+ checkArgument(
+ tensor.getDataType() == DataType.FLOAT32,
+ "Bounding Boxes only create from FLOAT32 buffers. Got: " + tensor.getDataType().name());
+ List<RectF> boundingBoxList = new ArrayList<>();
+ // Collapse dimensions to {a, 4, b}. So each bounding box could be represented as (i, j), and
+ // its four values are at (i, k, j), where 0 <= k < 4. The flattened index of each value is
+ // i * 4b + k * b + j.
+ int a = 1;
+ for (int i = 0; i < boundingBoxAxis; i++) {
+ a *= shape[i];
+ }
+ int b = 1;
+ for (int i = boundingBoxAxis + 1; i < shape.length; i++) {
+ b *= shape[i];
+ }
+ float[] values = new float[4];
+ ByteBuffer byteBuffer = tensor.getBuffer();
+ byteBuffer.rewind();
+ FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
+ for (int i = 0; i < a; i++) {
+ for (int j = 0; j < b; j++) {
+ for (int k = 0; k < 4; k++) {
+ values[k] = floatBuffer.get((i * 4 + k) * b + j);
+ }
+ boundingBoxList.add(
+ convertOneBoundingBox(values, valueIndex, type, coordinateType, height, width));
+ }
+ }
+ byteBuffer.rewind();
+ return boundingBoxList;
+ }
+
+ private static RectF convertOneBoundingBox(
+ float[] values,
+ int[] valueIndex,
+ Type type,
+ CoordinateType coordinateType,
+ int height,
+ int width) {
+ float[] orderedValues = new float[4];
+ for (int i = 0; i < 4; i++) {
+ orderedValues[i] = values[valueIndex[i]];
+ }
+ return convertOneBoundingBox(orderedValues, type, coordinateType, height, width);
+ }
+
+ private static RectF convertOneBoundingBox(
+ float[] values, Type type, CoordinateType coordinateType, int height, int width) {
+ switch (type) {
+ case BOUNDARIES:
+ return convertFromBoundaries(values, coordinateType, height, width);
+ case UPPER_LEFT:
+ return convertFromUpperLeft(values, coordinateType, height, width);
+ case CENTER:
+ return convertFromCenter(values, coordinateType, height, width);
+ }
+ throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type);
+ }
+
+ private static RectF convertFromBoundaries(
+ float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
+ float left = values[0];
+ float top = values[1];
+ float right = values[2];
+ float bottom = values[3];
+ return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
+ }
+
+ private static RectF convertFromUpperLeft(
+ float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
+ float left = values[0];
+ float top = values[1];
+ float right = values[0] + values[2];
+ float bottom = values[1] + values[3];
+ return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
+ }
+
+ private static RectF convertFromCenter(
+ float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) {
+ float centerX = values[0];
+ float centerY = values[1];
+ float w = values[2];
+ float h = values[3];
+
+ float left = centerX - w / 2;
+ float top = centerY - h / 2;
+ float right = centerX + w / 2;
+ float bottom = centerY + h / 2;
+ return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType);
+ }
+
+ private static RectF getRectF(
+ float left,
+ float top,
+ float right,
+ float bottom,
+ int imageHeight,
+ int imageWidth,
+ CoordinateType coordinateType) {
+ if (coordinateType == CoordinateType.PIXEL) {
+ return new RectF(left, top, right, bottom);
+ } else if (coordinateType == CoordinateType.RATIO) {
+ return new RectF(
+ left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight);
+ } else {
+ throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType);
+ }
+ }
+
+ // Private constructor to prevent instantiation.
+ private BoundingBoxUtil() {}
+}
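A minimal sketch of calling convert() on a typical detection output. The tensor layout, value order, and wrapper class are assumptions for illustration, not part of this change:

    import android.graphics.RectF;
    import java.util.List;
    import org.tensorflow.lite.support.image.BoundingBoxUtil;
    import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

    final class BoxConversionExample {
      static List<RectF> toPixelRects(TensorBuffer locations, int imageHeight, int imageWidth) {
        // Assumed layout: FLOAT32 tensor of shape {1, numBoxes, 4} holding
        // {top, left, bottom, right} ratios, as many detection models emit.
        return BoundingBoxUtil.convert(
            locations,
            new int[] {1, 0, 3, 2}, // map {top, left, bottom, right} to {left, top, right, bottom}
            /* boundingBoxAxis= */ 2, // the axis whose size is 4
            BoundingBoxUtil.Type.BOUNDARIES,
            BoundingBoxUtil.CoordinateType.RATIO,
            imageHeight,
            imageWidth);
      }
    }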
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
new file mode 100644
index 00000000..e92d0959
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java
@@ -0,0 +1,212 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image;
+
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
+
+import android.graphics.Bitmap;
+import android.graphics.Bitmap.Config;
+import java.util.Arrays;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/** Represents the type of color space of an image. */
+public enum ColorSpaceType {
+ /** Each pixel has red, green, and blue color components. */
+ RGB {
+
+ // The channel axis should always be 3 for RGB images.
+ private static final int CHANNEL_VALUE = 3;
+
+ @Override
+ Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
+ return ImageConversions.convertRgbTensorBufferToBitmap(buffer);
+ }
+
+ @Override
+ int getChannelValue() {
+ return CHANNEL_VALUE;
+ }
+
+ @Override
+ int[] getNormalizedShape(int[] shape) {
+ switch (shape.length) {
+ // The shape is in (h, w, c) format.
+ case 3:
+ return insertValue(shape, BATCH_DIM, BATCH_VALUE);
+ case 4:
+ return shape;
+ default:
+ throw new IllegalArgumentException(
+ getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
+ }
+ }
+
+ @Override
+ String getShapeInfoMessage() {
+ return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels"
+ + " representing R, G, B in order. ";
+ }
+
+ @Override
+ Config toBitmapConfig() {
+ return Config.ARGB_8888;
+ }
+ },
+
+ /** Each pixel is a single element representing only the amount of light. */
+ GRAYSCALE {
+
+ // The channel axis should always be 1 for grayscale images.
+ private static final int CHANNEL_VALUE = 1;
+
+ @Override
+ Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) {
+ return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer);
+ }
+
+ @Override
+ int getChannelValue() {
+ return CHANNEL_VALUE;
+ }
+
+ @Override
+ int[] getNormalizedShape(int[] shape) {
+ switch (shape.length) {
+ // The shape is in (h, w) format.
+ case 2:
+ int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE);
+ return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE);
+ case 4:
+ return shape;
+ default:
+ // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since they
+ // both have three dimensions, it will require extra info to differentiate between them.
+ // Since we haven't encountered real use cases of these two shapes, they are not supported
+ // at this moment to avoid confusion. We may want to revisit it in the future.
+ throw new IllegalArgumentException(
+ getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
+ }
+ }
+
+ @Override
+ String getShapeInfoMessage() {
+ return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). ";
+ }
+
+ @Override
+ Config toBitmapConfig() {
+ return Config.ALPHA_8;
+ }
+ };
+
+ private static final int BATCH_DIM = 0; // The first element of the normalized shape.
+ private static final int BATCH_VALUE = 1; // The batch axis should always be one.
+ private static final int HEIGHT_DIM = 1; // The second element of the normalized shape.
+ private static final int WIDTH_DIM = 2; // The third element of the normalized shape.
+ private static final int CHANNEL_DIM = 3; // The fourth element of the normalized shape.
+
+ /**
+ * Converts a bitmap configuration into the corresponding color space type.
+ *
+ * @throws IllegalArgumentException if the config is unsupported
+ */
+ static ColorSpaceType fromBitmapConfig(Config config) {
+ switch (config) {
+ case ARGB_8888:
+ return ColorSpaceType.RGB;
+ case ALPHA_8:
+ return ColorSpaceType.GRAYSCALE;
+ default:
+ throw new IllegalArgumentException(
+ "Bitmap configuration: " + config + ", is not supported yet.");
+ }
+ }
+
+ /**
+ * Verifies if the given shape matches the color space type.
+ *
+ * @throws IllegalArgumentException if {@code shape} does not match the color space type
+ */
+ void assertShape(int[] shape) {
+ int[] normalizedShape = getNormalizedShape(shape);
+ checkArgument(
+ isValidNormalizedShape(normalizedShape),
+ getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape));
+ }
+
+ /**
+ * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space type.
+ *
+ * @throws IllegalArgumentException if the shape of buffer does not match the color space type
+ */
+ abstract Bitmap convertTensorBufferToBitmap(TensorBuffer buffer);
+
+ /**
+ * Returns the width of the given shape corresponding to the color space type.
+ *
+ * @throws IllegalArgumentException if {@code shape} does not match the color space type
+ */
+ int getWidth(int[] shape) {
+ assertShape(shape);
+ return getNormalizedShape(shape)[WIDTH_DIM];
+ }
+
+ /**
+ * Returns the height of the given shape corresponding to the color space type.
+ *
+ * @throws IllegalArgumentException if {@code shape} does not match the color space type
+ */
+ int getHeight(int[] shape) {
+ assertShape(shape);
+ return getNormalizedShape(shape)[HEIGHT_DIM];
+ }
+
+ abstract int getChannelValue();
+
+ /**
+ * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have
+ * batch or channel axis.
+ */
+ abstract int[] getNormalizedShape(int[] shape);
+
+ abstract String getShapeInfoMessage();
+
+ /** Converts the color space type to the corresponding bitmap config. */
+ abstract Config toBitmapConfig();
+
+ /** Inserts a value at the specified position and returns the new array. */
+ private static int[] insertValue(int[] array, int pos, int value) {
+ int[] newArray = new int[array.length + 1];
+ for (int i = 0; i < pos; i++) {
+ newArray[i] = array[i];
+ }
+ newArray[pos] = value;
+ for (int i = pos + 1; i < newArray.length; i++) {
+ newArray[i] = array[i - 1];
+ }
+ return newArray;
+ }
+
+ protected boolean isValidNormalizedShape(int[] shape) {
+ return shape[BATCH_DIM] == BATCH_VALUE
+ && shape[HEIGHT_DIM] > 0
+ && shape[WIDTH_DIM] > 0
+ && shape[CHANNEL_DIM] == getChannelValue();
+ }
+}
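A small sketch of how the helpers above interact. Note that assertShape(), getWidth(), and getHeight() are package-private, so this hypothetical caller would have to live in org.tensorflow.lite.support.image; the shape values are illustrative:

    package org.tensorflow.lite.support.image;

    final class ShapeCheckExample {
      static void demo() {
        int[] shape = {1, 224, 224, 3}; // hypothetical normalized RGB shape (1, h, w, c)
        ColorSpaceType rgb = ColorSpaceType.RGB;
        rgb.assertShape(shape);            // passes: batch == 1 and channel == 3
        int width = rgb.getWidth(shape);   // 224
        int height = rgb.getHeight(shape); // 224
        // ColorSpaceType.GRAYSCALE.assertShape(shape) would throw IllegalArgumentException,
        // since a grayscale shape must be (h, w) or (1, h, w, 1).
      }
    }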
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
new file mode 100644
index 00000000..1a145de7
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image;
+
+import android.graphics.Bitmap;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * Handles image conversion across different image types.
+ *
+ * <p>An {@link ImageContainer} should support the conversion from the underlying image format to
+ * the following image types:
+ *
+ * <ul>
+ * <li>{@link Bitmap}
+ * <li>{@link TensorBuffer} of the specified data type.
+ * </ul>
+ */
+interface ImageContainer {
+
+ /** Performs a deep copy of the {@link ImageContainer}. */
+ ImageContainer clone();
+
+ /** Returns the width of the image. */
+ int getWidth();
+
+ /** Returns the height of the image. */
+ int getHeight();
+
+ /** Gets the {@link Bitmap} representation of the underlying image format. */
+ Bitmap getBitmap();
+
+ /**
+ * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the
+ * underlying image format.
+ */
+ TensorBuffer getTensorBuffer(DataType dataType);
+
+ /** Returns the color space type of the image. */
+ ColorSpaceType getColorSpaceType();
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
new file mode 100644
index 00000000..d6e567a2
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java
@@ -0,0 +1,146 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image;
+
+import android.graphics.Bitmap;
+import android.graphics.Color;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * Implements some stateless image conversion methods.
+ *
+ * <p>This class is an internal helper for {@link org.tensorflow.lite.support.image}.
+ */
+class ImageConversions {
+
+ /**
+ * Converts a {@link TensorBuffer} that represents an RGB image to an ARGB_8888 Bitmap.
+ *
+ * <p>Data in buffer will be converted into integer to match the Bitmap API.
+ *
+ * @param buffer an RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3)
+ * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3)
+ */
+ static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) {
+ int[] shape = buffer.getShape();
+ ColorSpaceType rgb = ColorSpaceType.RGB;
+ rgb.assertShape(shape);
+
+ int h = rgb.getHeight(shape);
+ int w = rgb.getWidth(shape);
+ Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig());
+
+ // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
+ int[] intValues = new int[w * h];
+ int[] rgbValues = buffer.getIntArray();
+ for (int i = 0, j = 0; i < intValues.length; i++) {
+ int r = rgbValues[j++];
+ int g = rgbValues[j++];
+ int b = rgbValues[j++];
+ intValues[i] = Color.rgb(r, g, b);
+ }
+ bitmap.setPixels(intValues, 0, w, 0, 0, w, h);
+
+ return bitmap;
+ }
+
+ /**
+ * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap.
+ *
+ * <p>Data in buffer will be converted into integer to match the Bitmap API.
+ *
+ * @param buffer a grayscale image. Its shape should be either (h, w) or (1, h, w, 1)
+ * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1)
+ */
+ static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) {
+ // Convert buffer into Uint8 as needed.
+ TensorBuffer uint8Buffer =
+ buffer.getDataType() == DataType.UINT8
+ ? buffer
+ : TensorBuffer.createFrom(buffer, DataType.UINT8);
+
+ int[] shape = uint8Buffer.getShape();
+ ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE;
+ grayscale.assertShape(shape);
+
+ // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config config)`
+ // seems to work in the internal Android testing framework, it actually doesn't work in a
+ // real Android environment.
+ //
+ // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to load
+ // the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out.
+ // Note: for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work.
+ Bitmap bitmap =
+ Bitmap.createBitmap(
+ grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig());
+ uint8Buffer.getBuffer().rewind();
+ bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer());
+ return bitmap;
+ }
+
+ /**
+ * Converts an image in a {@link Bitmap} to a {@link TensorBuffer} (a 3D tensor of shape
+ * height-width-channel) whose memory is either already allocated or can be dynamically allocated.
+ *
+ * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888
+ * config.
+ * @param buffer The destination of the conversion. Needs to be created in advance. If it's
+ * fixed-size, its flat size should be w*h*3.
+ * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match.
+ */
+ static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) {
+ int w = bitmap.getWidth();
+ int h = bitmap.getHeight();
+ int[] intValues = new int[w * h];
+ bitmap.getPixels(intValues, 0, w, 0, 0, w, h);
+ // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time.
+ int flatSize = w * h * 3;
+ int[] shape = new int[] {h, w, 3};
+ switch (buffer.getDataType()) {
+ case UINT8:
+ byte[] byteArr = new byte[w * h * 3];
+ for (int i = 0, j = 0; i < intValues.length; i++) {
+ byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff);
+ byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff);
+ byteArr[j++] = (byte) (intValues[i] & 0xff);
+ }
+ ByteBuffer byteBuffer = ByteBuffer.allocateDirect(flatSize);
+ byteBuffer.order(ByteOrder.nativeOrder());
+ byteBuffer.put(byteArr);
+ buffer.loadBuffer(byteBuffer, shape);
+ break;
+ case FLOAT32:
+ float[] floatArr = new float[w * h * 3];
+ for (int i = 0, j = 0; i < intValues.length; i++) {
+ floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff);
+ floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff);
+ floatArr[j++] = (float) (intValues[i] & 0xff);
+ }
+ buffer.loadArray(floatArr, shape);
+ break;
+ default:
+ // Should never happen.
+ throw new IllegalStateException(
+ "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported.");
+ }
+ }
+
+ // Hide the constructor as this class only provides static methods.
+ private ImageConversions() {}
+}
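A worked example of the bit arithmetic used above, for one hypothetical ARGB pixel; the wrapper class is illustrative:

    final class PixelSplitExample {
      static void demo() {
        int pixel = 0xFF336699;        // hypothetical ARGB pixel: A=0xFF, R=0x33, G=0x66, B=0x99
        int r = (pixel >> 16) & 0xff;  // 0x33 = 51
        int g = (pixel >> 8) & 0xff;   // 0x66 = 102
        int b = pixel & 0xff;          // 0x99 = 153
        // convertBitmapToTensorBuffer() stores these as consecutive {r, g, b} values, and
        // Color.rgb(r, g, b) re-packs them in convertRgbTensorBufferToBitmap().
      }
    }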
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
new file mode 100644
index 00000000..1e546634
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java
@@ -0,0 +1,43 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image;
+
+import android.graphics.PointF;
+import org.tensorflow.lite.support.common.Operator;
+
+ /** Operates on a {@link TensorImage} object. Used in {@link ImageProcessor}. */
+public interface ImageOperator extends Operator<TensorImage> {
+ /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */
+ @Override
+ TensorImage apply(TensorImage image);
+
+ /** Computes the width of the expected output image when input image size is given. */
+ int getOutputImageWidth(int inputImageHeight, int inputImageWidth);
+
+ /** Computes the height of the expected output image when input image size is given. */
+ int getOutputImageHeight(int inputImageHeight, int inputImageWidth);
+
+ /**
+ * Transforms a point from the coordinate system of the result image back to that of the input
+ * image.
+ *
+ * @param point the point in the result coordinate system.
+ * @param inputImageHeight the height of the input image.
+ * @param inputImageWidth the width of the input image.
+ * @return the point with coordinates in the coordinate system of the input image.
+ */
+ PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth);
+}
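A minimal sketch of a size-preserving operator implementing this interface. The class is entirely hypothetical and exists only to show the geometry contract:

    import android.graphics.PointF;
    import org.tensorflow.lite.support.image.ImageOperator;
    import org.tensorflow.lite.support.image.TensorImage;

    /** Hypothetical no-op operator, for illustration only. */
    final class IdentityOp implements ImageOperator {
      @Override
      public TensorImage apply(TensorImage image) {
        return image; // geometry is unchanged, so the methods below are identities
      }

      @Override
      public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
        return inputImageWidth;
      }

      @Override
      public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
        return inputImageHeight;
      }

      @Override
      public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
        return point;
      }
    }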
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
new file mode 100644
index 00000000..e1ef1309
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java
@@ -0,0 +1,198 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image;
+
+import android.graphics.PointF;
+import android.graphics.RectF;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.ListIterator;
+import org.tensorflow.lite.support.common.Operator;
+import org.tensorflow.lite.support.common.SequentialProcessor;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.image.ops.Rot90Op;
+import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper;
+
+/**
+ * ImageProcessor is a helper class for preprocessing and postprocessing {@link TensorImage}. It
+ * could transform a {@link TensorImage} to another by executing a chain of {@link ImageOperator}.
+ *
+ * <p>Example Usage:
+ *
+ * <pre>
+ * ImageProcessor processor = new ImageProcessor.Builder()
+ * .add(new ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOR))
+ * .add(new Rot90Op())
+ * .add(new NormalizeOp(127.5f, 127.5f))
+ * .build();
+ * TensorImage anotherTensorImage = processor.process(tensorImage);
+ * </pre>
+ *
+ * <p><b>WARNING:</b> Instances of an {@code ImageProcessor} are <b>not</b> thread-safe with {@link
+ * #updateNumberOfRotations}. Updating the number of rotations and then processing images (using
+ * {@link #process}) must be protected from concurrent access. It is recommended to create separate
+ * {@code ImageProcessor} instances for each thread. If multiple threads access an {@code
+ * ImageProcessor} concurrently, it must be synchronized externally.
+ *
+ * @see ImageProcessor.Builder to build an {@link ImageProcessor} instance
+ * @see ImageProcessor#process(TensorImage) to apply the processor on a {@link TensorImage}
+ */
+public class ImageProcessor extends SequentialProcessor<TensorImage> {
+ private ImageProcessor(Builder builder) {
+ super(builder);
+ }
+
+ /**
+ * Transforms a point from the coordinate system of the result image back to that of the input
+ * image.
+ *
+ * @param point the point in the result coordinate system.
+ * @param inputImageHeight the height of the input image.
+ * @param inputImageWidth the width of the input image.
+ * @return the point with coordinates in the coordinate system of the input image.
+ */
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ List<Integer> widths = new ArrayList<>();
+ List<Integer> heights = new ArrayList<>();
+ int currentWidth = inputImageWidth;
+ int currentHeight = inputImageHeight;
+ for (Operator<TensorImage> op : operatorList) {
+ widths.add(currentWidth);
+ heights.add(currentHeight);
+ ImageOperator imageOperator = (ImageOperator) op;
+ int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth);
+ int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth);
+ currentHeight = newHeight;
+ currentWidth = newWidth;
+ }
+ ListIterator<Operator<TensorImage>> opIterator = operatorList.listIterator(operatorList.size());
+ ListIterator<Integer> widthIterator = widths.listIterator(widths.size());
+ ListIterator<Integer> heightIterator = heights.listIterator(heights.size());
+ while (opIterator.hasPrevious()) {
+ ImageOperator imageOperator = (ImageOperator) opIterator.previous();
+ int height = heightIterator.previous();
+ int width = widthIterator.previous();
+ point = imageOperator.inverseTransform(point, height, width);
+ }
+ return point;
+ }
+
+ /**
+ * Transforms a rectangle from the coordinate system of the result image back to that of the
+ * input image.
+ *
+ * @param rect the rectangle in the result coordinate system.
+ * @param inputImageHeight the height of the input image.
+ * @param inputImageWidth the width of the input image.
+ * @return the rectangle with coordinates in the coordinate system of the input image.
+ */
+ public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) {
+ // When rotation is involved, the corner order may change; e.g. the top left may become the
+ // bottom right.
+ PointF p1 =
+ inverseTransform(new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth);
+ PointF p2 =
+ inverseTransform(new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth);
+ return new RectF(
+ Math.min(p1.x, p2.x), Math.min(p1.y, p2.y), Math.max(p1.x, p2.x), Math.max(p1.y, p2.y));
+ }
+
+ /**
+ * The Builder to create an ImageProcessor, which could be executed later.
+ *
+ * @see #add(TensorOperator) to add a general TensorOperator
+ * @see #add(ImageOperator) to add an ImageOperator
+ * @see #build() complete the building process and get a built Processor
+ */
+ public static class Builder extends SequentialProcessor.Builder<TensorImage> {
+ public Builder() {
+ super();
+ }
+
+ /**
+ * Adds an {@link ImageOperator} into the Operator chain.
+ *
+ * @param op the operator to be appended to the chain
+ */
+ public Builder add(ImageOperator op) {
+ super.add(op);
+ return this;
+ }
+
+ /**
+ * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls
+ * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by transforming
+ * the underlying {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
+ *
+ * @param op the operator to be appended to the chain
+ */
+ public Builder add(TensorOperator op) {
+ return add(new TensorOperatorWrapper(op));
+ }
+
+ /** Completes the building process and gets the {@link ImageProcessor} instance. */
+ @Override
+ public ImageProcessor build() {
+ return new ImageProcessor(this);
+ }
+ }
+
+ /**
+ * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}.
+ *
+ * <p><b>WARNING:</b> this method is <b>not</b> thread-safe. Updating the number of rotations and
+ * then processing images (using {@link #process}) must be protected from concurrent access with
+ * additional synchronization.
+ *
+ * @param k the number of rotations
+ * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
+ * ImageProcessor}
+ */
+ public void updateNumberOfRotations(int k) {
+ updateNumberOfRotations(k, /*occurrence=*/ 0);
+ }
+
+ /**
+ * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in this
+ * {@link ImageProcessor}.
+ *
+ * <p><b>WARNING:</b> this method is <b>not</b> thread-safe. Updating the number of rotations and
+ * then processing images (using {@link #process}) must be protected from concurrent access with
+ * additional synchronization.
+ *
+ * @param k the number of rotations
+ * @param occurrence the index of a particular {@link Rot90Op} in this {@link ImageProcessor}. For
+ * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be
+ * set to 1.
+ * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the
+ * number of {@link Rot90Op} in this {@link ImageProcessor}
+ * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link
+ * ImageProcessor}
+ */
+ public synchronized void updateNumberOfRotations(int k, int occurrence) {
+ SupportPreconditions.checkState(
+ operatorIndex.containsKey(Rot90Op.class.getName()),
+ "The Rot90Op has not been added to the ImageProcessor.");
+
+ List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName());
+ SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence");
+
+ // The index of the Rot90Op to be replaced in operatorList.
+ int index = indexes.get(occurrence);
+ Rot90Op newRot = new Rot90Op(k);
+ operatorList.set(index, newRot);
+ }
+}
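
A minimal usage sketch for the class above (illustrative, not part of the patch): it chains a Rot90Op and a ResizeOp, then maps a detection box from the result coordinates back to the original frame via inverseTransform. The 224x224 model input size, the frame, and the box values are hypothetical.

import android.graphics.Bitmap;
import android.graphics.RectF;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod;
import org.tensorflow.lite.support.image.ops.Rot90Op;

final class ImageProcessorExample {
  static RectF mapBoxToFrame(Bitmap frame) {
    // The processor records the operator order, so inverseTransform can undo
    // the chain back-to-front (resize first, then rotation).
    ImageProcessor processor =
        new ImageProcessor.Builder()
            .add(new Rot90Op(1)) // undo a 90-degree sensor rotation
            .add(new ResizeOp(224, 224, ResizeMethod.BILINEAR))
            .build();
    TensorImage image = processor.process(TensorImage.fromBitmap(frame));

    // A hypothetical detection box in the 224x224 result coordinates, mapped
    // back to the coordinate system of the original frame.
    RectF boxInModelCoords = new RectF(10, 20, 100, 120);
    return processor.inverseTransform(boxInModelCoords, frame.getHeight(), frame.getWidth());
  }
}
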
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
new file mode 100644
index 00000000..d047a8e0
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java
@@ -0,0 +1,90 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image;
+
+import android.graphics.Bitmap;
+import android.util.Log;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/** Holds a {@link TensorBuffer} and converts it to other image formats as needed. */
+final class TensorBufferContainer implements ImageContainer {
+
+ private final TensorBuffer buffer;
+ private final ColorSpaceType colorSpaceType;
+ private static final String TAG = TensorBufferContainer.class.getSimpleName();
+
+ /**
+ * Creates a {@link TensorBufferContainer} object with the specified {@link
+ * ColorSpaceType}.
+ *
+ * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the
+ * specified color space type
+ */
+ static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
+ return new TensorBufferContainer(buffer, colorSpaceType);
+ }
+
+ private TensorBufferContainer(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
+ colorSpaceType.assertShape(buffer.getShape());
+ this.buffer = buffer;
+ this.colorSpaceType = colorSpaceType;
+ }
+
+ @Override
+ public TensorBufferContainer clone() {
+ return create(TensorBuffer.createFrom(buffer, buffer.getDataType()), colorSpaceType);
+ }
+
+ @Override
+ public Bitmap getBitmap() {
+ if (buffer.getDataType() != DataType.UINT8) {
+ // Print a warning instead of throwing an exception. When using float models, users may want
+ // to convert the resulting float image into a Bitmap. That's fine, as long as they are aware
+ // of the potential accuracy loss when casting to uint8.
+ Log.w(
+ TAG,
+ "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap"
+ + " will cause numeric casting and clamping on the data value.");
+ }
+
+ return colorSpaceType.convertTensorBufferToBitmap(buffer);
+ }
+
+ @Override
+ public TensorBuffer getTensorBuffer(DataType dataType) {
+ // If the buffer already has the desired data type, return it directly; otherwise, create a
+ // copy with the expected data type. No defensive copy is made in the first case for
+ // performance reasons: during image processing, users may need to set and get the
+ // TensorBuffer many times.
+ return buffer.getDataType() == dataType ? buffer : TensorBuffer.createFrom(buffer, dataType);
+ }
+
+ @Override
+ public int getWidth() {
+ return colorSpaceType.getWidth(buffer.getShape());
+ }
+
+ @Override
+ public int getHeight() {
+ return colorSpaceType.getHeight(buffer.getShape());
+ }
+
+ @Override
+ public ColorSpaceType getColorSpaceType() {
+ return colorSpaceType;
+ }
+}
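
TensorBufferContainer is package-private, so the public route to it is TensorImage#load(TensorBuffer, ColorSpaceType). A short sketch of the path that triggers the non-uint8 warning above, assuming a hypothetical {1, 8, 8, 3} float buffer:

import android.graphics.Bitmap;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.image.ColorSpaceType;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

final class TensorBufferContainerExample {
  static Bitmap floatBufferToBitmap() {
    // A hypothetical {1, 8, 8, 3} RGB float buffer, e.g. a float model output.
    TensorBuffer buffer = TensorBuffer.createFixedSize(new int[] {1, 8, 8, 3}, DataType.FLOAT32);

    TensorImage image = new TensorImage(DataType.FLOAT32);
    image.load(buffer, ColorSpaceType.RGB); // wraps the buffer in a TensorBufferContainer

    // The buffer is FLOAT32, so this logs the casting/clamping warning above.
    return image.getBitmap();
  }
}
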
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
new file mode 100644
index 00000000..96cae716
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java
@@ -0,0 +1,312 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image;
+
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
+
+import android.graphics.Bitmap;
+import java.nio.ByteBuffer;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * TensorImage is the wrapper class for image objects. When using the image processing utilities
+ * in the TFLite Support library, it's common to first convert image objects of various types
+ * into a TensorImage.
+ *
+ * <p>At present, only RGB images are supported, and the A channel is always ignored.
+ *
+ * <p>Details of data storage: a {@link TensorImage} object may have 2 potential sources of truth: a
+ * {@link Bitmap} or a {@link TensorBuffer}. {@link TensorImage} maintains the state and only
+ * converts one to the other when needed. A typical use case of {@link TensorImage} is to first load
+ * a {@link Bitmap} image, then process it using {@link ImageProcessor}, and finally get the
+ * underlying {@link ByteBuffer} of the {@link TensorBuffer} and feed it into the TFLite
+ * interpreter.
+ *
+ * <p>IMPORTANT: to achieve the best performance, {@link TensorImage} avoids copying data whenever
+ * it's possible. Therefore, it doesn't own its data. Callers should not modify data objects that
+ * are passed to {@link TensorImage#load(Bitmap)} or {@link TensorImage#load(TensorBuffer)}.
+ *
+ * <p>IMPORTANT: none of the methods are guaranteed to be thread-safe.
+ *
+ * @see ImageProcessor which is often used for transforming a {@link TensorImage}.
+ */
+// TODO(b/138907116): Support loading images from TensorBuffer with properties.
+// TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary.
+public class TensorImage {
+
+ private final DataType dataType;
+ private ImageContainer container = null;
+
+ /**
+ * Initializes a {@link TensorImage} object.
+ *
+ * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link
+ * #TensorImage(DataType)} if other data types are preferred.
+ */
+ public TensorImage() {
+ this(DataType.UINT8);
+ }
+
+ /**
+ * Initializes a {@link TensorImage} object with the specified data type.
+ *
+ * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage},
+ * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be
+ * converted to the specified data type.
+ *
+ * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of
+ * the image being loaded to this {@link TensorImage}.
+ *
+ * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is
+ * always fixed during the lifetime of the {@link TensorImage}. To convert the data type, use
+ * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the
+ * same time.
+ * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor
+ * {@link DataType#FLOAT32}
+ */
+ public TensorImage(DataType dataType) {
+ checkArgument(
+ dataType == DataType.UINT8 || dataType == DataType.FLOAT32,
+ "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted");
+ this.dataType = dataType;
+ }
+
+ /**
+ * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link Bitmap}.
+ *
+ * @see TensorImage#load(Bitmap) for reusing the object when it's expensive to create objects
+ * frequently, because every call of {@code fromBitmap} creates a new {@link TensorImage}.
+ */
+ public static TensorImage fromBitmap(Bitmap bitmap) {
+ TensorImage image = new TensorImage();
+ image.load(bitmap);
+ return image;
+ }
+
+ /**
+ * Creates a deep-copy of a given {@link TensorImage} with the desired data type.
+ *
+ * @param src the {@link TensorImage} to copy from
+ * @param dataType the expected data type of newly created {@link TensorImage}
+ * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code
+ * dataType}
+ */
+ public static TensorImage createFrom(TensorImage src, DataType dataType) {
+ TensorImage dst = new TensorImage(dataType);
+ dst.container = src.container.clone();
+ return dst;
+ }
+
+ /**
+ * Loads a {@link Bitmap} image object into this {@link TensorImage}.
+ *
+ * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric
+ * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}, where the {@link Bitmap} will be converted into a {@link TensorBuffer}.
+ *
+ * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The
+ * {@link TensorImage} object will rely on the bitmap, and may modify it as well. This method
+ * takes a zero-copy approach: it simply holds a reference to the bitmap. Use {@code
+ * bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary.
+ *
+ * <p>Note: to get the best performance, please load images in the same shape to avoid memory
+ * re-allocation.
+ *
+ * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888
+ */
+ public void load(Bitmap bitmap) {
+ container = BitmapContainer.create(bitmap);
+ }
+
+ /**
+ * Loads a float array as RGB pixels into this {@link TensorImage}.
+ *
+ * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32},
+ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}.
+ *
+ * @param pixels the RGB pixels representing the image
+ * @param shape the shape of the image, which should be either in the form (h, w, 3) or (1, h, w, 3)
+ * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
+ */
+ public void load(float[] pixels, int[] shape) {
+ TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
+ buffer.loadArray(pixels, shape);
+ load(buffer);
+ }
+
+ /**
+ * Loads an int array as RGB pixels into this {@link TensorImage}.
+ *
+ * <p>Note: numeric casting and clamping will be applied to convert the values into the data type
+ * of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link #getBuffer}.
+ *
+ * @param pixels the RGB pixels representing the image
+ * @param shape the shape of the image, which should be either in the form (h, w, 3) or (1, h, w, 3)
+ * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
+ */
+ public void load(int[] pixels, int[] shape) {
+ TensorBuffer buffer = TensorBuffer.createDynamic(getDataType());
+ buffer.loadArray(pixels, shape);
+ load(buffer);
+ }
+
+ /**
+ * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB.
+ *
+ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
+ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}.
+ *
+ * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or
+ * (1, h, w, 3)
+ * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3)
+ */
+ public void load(TensorBuffer buffer) {
+ load(buffer, ColorSpaceType.RGB);
+ }
+
+ /**
+ * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ColorSpaceType}.
+ *
+ * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage},
+ * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link
+ * #getBuffer}.
+ *
+ * @throws IllegalArgumentException if the shape of buffer does not match the color space type
+ * @see ColorSpaceType#assertShape
+ */
+ public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) {
+ container = TensorBufferContainer.create(buffer, colorSpaceType);
+ }
+
+ /**
+ * Returns a {@link Bitmap} representation of this {@link TensorImage}.
+ *
+ * <p>Numeric casting and clamping will be applied if the stored data is not uint8.
+ *
+ * <p>Note that the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code
+ * copyPixelsToBuffer}. Bitmap methods such as {@code setPixels()} and {@code getPixels()} do not work.
+ *
+ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
+ * reasons, but if modification is necessary, please make a copy.
+ *
+ * @return a reference to a {@link Bitmap} in {@code ARGB_8888} config ("A" channel is always
+ * opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of this {@link
+ * TensorImage}.
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ */
+ public Bitmap getBitmap() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getBitmap();
+ }
+
+ /**
+ * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected data
+ * type.
+ *
+ * <p>Numeric casting and clamping will be applied if the stored data type is different from the
+ * data type of the {@link TensorImage}.
+ *
+ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
+ * reasons, but if modification is necessary, please make a copy.
+ *
+ * <p>It's essentially a shortcut for {@code getTensorBuffer().getBuffer()}.
+ *
+ * @return a reference to a {@link ByteBuffer} which holds the image data
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ */
+ public ByteBuffer getBuffer() {
+ return getTensorBuffer().getBuffer();
+ }
+
+ /**
+ * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected
+ * data type.
+ *
+ * <p>Numeric casting and clamping will be applied if the stored data type is different from the
+ * data type of the {@link TensorImage}.
+ *
+ * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance
+ * reasons, but if modification is necessary, please make a copy.
+ *
+ * @return a reference to a {@link TensorBuffer} which holds the image data
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ */
+ public TensorBuffer getTensorBuffer() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getTensorBuffer(dataType);
+ }
+
+ /**
+ * Gets the data type of this {@link TensorImage}.
+ *
+ * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are
+ * supported.
+ */
+ public DataType getDataType() {
+ return dataType;
+ }
+
+ /**
+ * Gets the color space type of this {@link TensorImage}.
+ *
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ */
+ public ColorSpaceType getColorSpaceType() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getColorSpaceType();
+ }
+
+ /**
+ * Gets the image width.
+ *
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ * @throws IllegalArgumentException if the underlying data is corrupted
+ */
+ public int getWidth() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getWidth();
+ }
+
+ /**
+ * Gets the image height.
+ *
+ * @throws IllegalStateException if the {@link TensorImage} never loads data
+ * @throws IllegalArgumentException if the underlying data is corrupted
+ */
+ public int getHeight() {
+ if (container == null) {
+ throw new IllegalStateException("No image has been loaded yet.");
+ }
+
+ return container.getHeight();
+ }
+}
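
A sketch of the typical flow from the class Javadoc (load a Bitmap, convert the data type, feed the buffer to an interpreter). The asset name "mobilenet.tflite" and the {1, 1001} output shape are placeholder assumptions:

import android.content.Context;
import android.graphics.Bitmap;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

final class TensorImageExample {
  // bitmap must be in ARGB_8888 config.
  static float[] runFloatModel(Context context, Bitmap bitmap) throws IOException {
    MappedByteBuffer modelBuffer = FileUtil.loadMappedFile(context, "mobilenet.tflite");
    try (Interpreter interpreter = new Interpreter(modelBuffer)) {
      // Load as uint8 first, then make a FLOAT32 deep copy for a float model.
      TensorImage uint8Image = TensorImage.fromBitmap(bitmap);
      TensorImage floatImage = TensorImage.createFrom(uint8Image, DataType.FLOAT32);

      TensorBuffer output = TensorBuffer.createFixedSize(new int[] {1, 1001}, DataType.FLOAT32);
      ByteBuffer input = floatImage.getBuffer();
      input.rewind(); // make sure the interpreter reads from the start
      ByteBuffer outBuffer = output.getBuffer();
      outBuffer.rewind();
      interpreter.run(input, outBuffer);
      return output.getFloatArray();
    }
  }
}
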
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
new file mode 100644
index 00000000..35606dd6
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java
@@ -0,0 +1,89 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image.ops;
+
+import android.graphics.Bitmap;
+import android.graphics.PointF;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.support.image.ImageOperator;
+import org.tensorflow.lite.support.image.TensorImage;
+
+/**
+ * As a computation unit for processing images, it can resize an image to a user-specified size.
+ *
+ * <p>It interpolates pixels when the image is stretched, and discards pixels when the image is
+ * compressed.
+ *
+ * @see ResizeWithCropOrPadOp for resizing without content distortion.
+ */
+public class ResizeOp implements ImageOperator {
+
+ /** Algorithms for resizing. */
+ public enum ResizeMethod {
+ BILINEAR,
+ NEAREST_NEIGHBOR
+ }
+
+ private final int targetHeight;
+ private final int targetWidth;
+ private final boolean useBilinear;
+
+ /**
+ * Creates a ResizeOp which can resize images to a specified size using the specified method.
+ *
+ * @param targetHeight the expected height of the resized image.
+ * @param targetWidth the expected width of the resized image.
+ * @param resizeMethod the algorithm to use for resizing. Options: {@link ResizeMethod}
+ */
+ public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) {
+ this.targetHeight = targetHeight;
+ this.targetWidth = targetWidth;
+ useBilinear = (resizeMethod == ResizeMethod.BILINEAR);
+ }
+
+ /**
+ * Applies the defined resizing on given image and returns the result.
+ *
+ * <p>Note: the content of the input {@code image} will change, and {@code image} is the same
+ * instance as the output.
+ *
+ * @param image the input image.
+ * @return the output image.
+ */
+ @Override
+ @NonNull
+ public TensorImage apply(@NonNull TensorImage image) {
+ Bitmap scaled =
+ Bitmap.createScaledBitmap(image.getBitmap(), targetWidth, targetHeight, useBilinear);
+ image.load(scaled);
+ return image;
+ }
+
+ @Override
+ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
+ return targetHeight;
+ }
+
+ @Override
+ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
+ return targetWidth;
+ }
+
+ @Override
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ return new PointF(
+ point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight);
+ }
+}
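
A worked example of the inverse mapping above (hypothetical sizes): a keypoint in a 224x224 resized image is scaled back into a 640x480 original by point.x * inputImageWidth / targetWidth, and the analogous expression for y.

import android.graphics.PointF;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod;

final class ResizeOpExample {
  static PointF mapKeypointBack() {
    ResizeOp resize = new ResizeOp(224, 224, ResizeMethod.BILINEAR);
    // A keypoint at (112, 112) in the 224x224 result, with the original image
    // being 640x480 (h=480, w=640):
    PointF p = resize.inverseTransform(new PointF(112, 112), 480, 640);
    // p == (112 * 640 / 224, 112 * 480 / 224) == (320, 240), the image center.
    return p;
  }
}
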
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
new file mode 100644
index 00000000..404429ef
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java
@@ -0,0 +1,125 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image.ops;
+
+import android.graphics.Bitmap;
+import android.graphics.Bitmap.Config;
+import android.graphics.Canvas;
+import android.graphics.PointF;
+import android.graphics.Rect;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.support.image.ImageOperator;
+import org.tensorflow.lite.support.image.TensorImage;
+
+/**
+ * As a computation unit for processing images, it can resize an image to a predefined size.
+ *
+ * <p>It will not stretch or compress the content of the image. However, to fit the new size, it
+ * crops or pads pixels. When it crops the image, it performs a center-crop; when it pads pixels,
+ * it performs zero-padding.
+ *
+ * @see ResizeOp for resizing images while stretching / compressing the content.
+ */
+public class ResizeWithCropOrPadOp implements ImageOperator {
+ private final int targetHeight;
+ private final int targetWidth;
+ private final Bitmap output;
+
+ /**
+ * Creates a ResizeWithCropOrPadOp which can crop and/or pad images to a specified size. It
+ * adopts center-crop and zero-padding.
+ *
+ * @param targetHeight the expected height of the cropped/padded image.
+ * @param targetWidth the expected width of the cropped/padded image.
+ */
+ public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) {
+ this.targetHeight = targetHeight;
+ this.targetWidth = targetWidth;
+ output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888);
+ }
+
+ /**
+ * Applies the defined resizing with cropping and/or padding on the given image and returns the
+ * result.
+ *
+ * <p>Note: the content of the input {@code image} will change, and {@code image} is the same
+ * instance as the output.
+ *
+ * @param image the input image.
+ * @return the output image.
+ */
+ @Override
+ @NonNull
+ public TensorImage apply(@NonNull TensorImage image) {
+ Bitmap input = image.getBitmap();
+ int srcL;
+ int srcR;
+ int srcT;
+ int srcB;
+ int dstL;
+ int dstR;
+ int dstT;
+ int dstB;
+ int w = input.getWidth();
+ int h = input.getHeight();
+ if (targetWidth > w) { // padding
+ srcL = 0;
+ srcR = w;
+ dstL = (targetWidth - w) / 2;
+ dstR = dstL + w;
+ } else { // cropping
+ dstL = 0;
+ dstR = targetWidth;
+ srcL = (w - targetWidth) / 2;
+ srcR = srcL + targetWidth;
+ }
+ if (targetHeight > h) { // padding
+ srcT = 0;
+ srcB = h;
+ dstT = (targetHeight - h) / 2;
+ dstB = dstT + h;
+ } else { // cropping
+ dstT = 0;
+ dstB = targetHeight;
+ srcT = (h - targetHeight) / 2;
+ srcB = srcT + targetHeight;
+ }
+ Rect src = new Rect(srcL, srcT, srcR, srcB);
+ Rect dst = new Rect(dstL, dstT, dstR, dstB);
+ new Canvas(output).drawBitmap(input, src, dst, null);
+ image.load(output);
+ return image;
+ }
+
+ @Override
+ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
+ return targetHeight;
+ }
+
+ @Override
+ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
+ return targetWidth;
+ }
+
+ @Override
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth);
+ }
+
+ private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) {
+ return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2);
+ }
+}
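
A worked example of the inverse mapping (hypothetical sizes). With a 300x300 target on a 200x400 input, the width is center-cropped (source columns 50..350) and the height is zero-padded (destination rows 50..250), so inverseTransform shifts x by +50 and y by -50:

import android.graphics.PointF;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;

final class CropOrPadExample {
  static PointF mapCornerBack() {
    ResizeWithCropOrPadOp op = new ResizeWithCropOrPadOp(300, 300);
    // The top-left corner of the 300x300 result, mapped back to a 200x400
    // (h=200, w=400) original:
    PointF p = op.inverseTransform(new PointF(0, 0), 200, 400);
    // p == (50, -50): x moves right past the cropped columns; y is negative
    // because the first 50 result rows are padding outside the original image.
    return p;
  }
}
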
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
new file mode 100644
index 00000000..2fa22937
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java
@@ -0,0 +1,103 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image.ops;
+
+import android.graphics.Bitmap;
+import android.graphics.Matrix;
+import android.graphics.PointF;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.support.image.ImageOperator;
+import org.tensorflow.lite.support.image.TensorImage;
+
+/** Rotates an image counter-clockwise. */
+public class Rot90Op implements ImageOperator {
+
+ private final int numRotation;
+
+ /** Creates a Rot90 Op which will rotate the image by 90 degrees counter-clockwise. */
+ public Rot90Op() {
+ this(1);
+ }
+
+ /**
+ * Creates a Rot90 Op which will rotate the image by 90 degrees counter-clockwise {@code k}
+ * times.
+ *
+ * @param k the number of times the image is rotated by 90 degrees. If it's positive, the image
+ * will be rotated counter-clockwise. If it's negative, the op will rotate the image clockwise.
+ */
+ public Rot90Op(int k) {
+ numRotation = k % 4;
+ }
+
+ /**
+ * Applies the defined rotation on the given image and returns the result.
+ *
+ * <p>Note: the content of the input {@code image} will change, and {@code image} is the same
+ * instance as the output.
+ *
+ * @param image the input image.
+ * @return the output image.
+ */
+ @NonNull
+ @Override
+ public TensorImage apply(@NonNull TensorImage image) {
+ Bitmap input = image.getBitmap();
+ if (numRotation == 0) {
+ return image;
+ }
+ int w = input.getWidth();
+ int h = input.getHeight();
+ Matrix matrix = new Matrix();
+ matrix.postTranslate(w * 0.5f, h * 0.5f);
+ matrix.postRotate(-90 * numRotation);
+ int newW = (numRotation % 2 == 0) ? w : h;
+ int newH = (numRotation % 2 == 0) ? h : w;
+ matrix.postTranslate(newW * 0.5f, newH * 0.5f);
+ Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false);
+ image.load(output);
+ return image;
+ }
+
+ @Override
+ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
+ return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth;
+ }
+
+ @Override
+ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
+ return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight;
+ }
+
+ @Override
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ int inverseNumRotation = (4 - numRotation) % 4;
+ int height = getOutputImageHeight(inputImageHeight, inputImageWidth);
+ int width = getOutputImageWidth(inputImageHeight, inputImageWidth);
+ return transformImpl(point, height, width, inverseNumRotation);
+ }
+
+ private static PointF transformImpl(PointF point, int height, int width, int numRotation) {
+ if (numRotation == 0) {
+ return point;
+ } else if (numRotation == 1) {
+ return new PointF(point.y, width - point.x);
+ } else if (numRotation == 2) {
+ return new PointF(width - point.x, height - point.y);
+ } else { // numRotation == 3
+ return new PointF(height - point.y, point.x);
+ }
+ }
+}
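
A worked example of the inverse mapping (hypothetical sizes): one counter-clockwise quarter turn of a 480x640 frame yields a 640x480 result, and a point in the result maps back through the numRotation == 3 branch of transformImpl:

import android.graphics.PointF;
import org.tensorflow.lite.support.image.ops.Rot90Op;

final class Rot90Example {
  static PointF mapPointBack() {
    Rot90Op op = new Rot90Op(1); // one counter-clockwise quarter turn
    // A point in the rotated 640x480 result, mapped back to the original
    // 480x640 frame (h=480, w=640):
    PointF p = op.inverseTransform(new PointF(100, 200), 480, 640);
    // inverseNumRotation == 3, so p == (height - point.y, point.x) == (440, 100),
    // where height (640) is the height of the rotated image.
    return p;
  }
}
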
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
new file mode 100644
index 00000000..420018dd
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java
@@ -0,0 +1,75 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.image.ops;
+
+import android.graphics.PointF;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+import org.tensorflow.lite.support.common.TensorOperator;
+import org.tensorflow.lite.support.image.ImageOperator;
+import org.tensorflow.lite.support.image.TensorImage;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * The adapter that makes a TensorOperator able to run with TensorImage.
+ *
+ * @see org.tensorflow.lite.support.common.TensorOperator
+ * @see org.tensorflow.lite.support.image.TensorImage
+ */
+public class TensorOperatorWrapper implements ImageOperator {
+
+ private final TensorOperator tensorOp;
+
+ /**
+ * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link
+ * TensorOperator} can handle {@link TensorImage} objects by operating on their underlying {@link
+ * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}.
+ *
+ * <p>Requirement: The {@code op} should not change the coordinate system when applied on an image.
+ *
+ * @param op the {@link TensorOperator} to be wrapped.
+ */
+ public TensorOperatorWrapper(TensorOperator op) {
+ tensorOp = op;
+ }
+
+ @Override
+ @NonNull
+ public TensorImage apply(@NonNull TensorImage image) {
+ SupportPreconditions.checkNotNull(image, "Op cannot be applied on a null image.");
+ TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer());
+ // Some ops may change the data type of the underlying TensorBuffer, such as CastOp. Therefore,
+ // we need to create a new TensorImage with the correct data type.
+ TensorImage resImage = new TensorImage(resBuffer.getDataType());
+ resImage.load(resBuffer);
+ return resImage;
+ }
+
+ @Override
+ public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) {
+ return inputImageHeight;
+ }
+
+ @Override
+ public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) {
+ return inputImageWidth;
+ }
+
+ @Override
+ public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) {
+ return point;
+ }
+}
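
A sketch of how the wrapper is typically reached: ImageProcessor.Builder#add(TensorOperator) wraps the op automatically, so buffer-level and image-level operators can share one pipeline. NormalizeOp here is the library's TensorOperator from the common/ops package (defined elsewhere in this patch); the mean/stddev values are hypothetical:

import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeOp.ResizeMethod;

final class WrapperExample {
  static ImageProcessor buildPipeline() {
    return new ImageProcessor.Builder()
        .add(new ResizeOp(224, 224, ResizeMethod.BILINEAR)) // ImageOperator
        // A TensorOperator: Builder#add(TensorOperator) wraps it in a
        // TensorOperatorWrapper behind the scenes.
        .add(new NormalizeOp(127.5f, 127.5f))
        .build();
  }
}
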
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
new file mode 100644
index 00000000..5b043a9f
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
@@ -0,0 +1,95 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.label;
+
+import java.util.Objects;
+import org.tensorflow.lite.annotations.UsedByReflection;
+
+/**
+ * Category is a utility class that contains a label, its display name, and a float score.
+ * Typically it's used as the result of classification tasks.
+ */
+@UsedByReflection("TFLiteSupport/Task")
+public final class Category {
+ private final String label;
+ private final String displayName;
+ private final float score;
+
+ /**
+ * Constructs a {@link Category} object.
+ *
+ * @param displayName the display name of the label, which may be translated for different
+ * locales. For example, a label, "apple", may be translated into Spanish for display purposes,
+ * so that the displayName is "manzana".
+ */
+ @UsedByReflection("TFLiteSupport/Task")
+ public static Category create(String label, String displayName, float score) {
+ return new Category(label, displayName, score);
+ }
+
+ /** Constructs a {@link Category} object with an empty displayName. */
+ @UsedByReflection("TFLiteSupport/Task")
+ public Category(String label, float score) {
+ this(label, /*displayName=*/ "", score);
+ }
+
+ private Category(String label, String displayName, float score) {
+ this.label = label;
+ this.displayName = displayName;
+ this.score = score;
+ }
+
+ /** Gets the reference to the category's label. */
+ public String getLabel() {
+ return label;
+ }
+
+ /**
+ * Gets the reference to the category's displayName, a localized name of the label.
+ *
+ * <p>The display name can be an empty string if this {@link Category} object is constructed
+ * without displayName, such as when using {@link #Category(String label, float score)}.
+ */
+ public String getDisplayName() {
+ return displayName;
+ }
+
+ /** Gets the score of the category. */
+ public float getScore() {
+ return score;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof Category) {
+ Category other = (Category) o;
+ return (other.getLabel().equals(this.label)
+ && other.getDisplayName().equals(this.displayName)
+ && other.getScore() == this.score);
+ }
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(label, displayName, score);
+ }
+
+ @Override
+ public String toString() {
+ return "<Category \"" + label + "\" (displayName=" + displayName + ") (score=" + score + ")>";
+ }
+}
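
A small sketch of the two construction paths and the equality semantics (hypothetical values):

import org.tensorflow.lite.support.label.Category;

final class CategoryExample {
  static boolean compare() {
    Category apple = Category.create("apple", "manzana", 0.9f);
    Category plain = new Category("apple", 0.9f); // empty displayName
    // false: labels and scores match, but displayNames differ ("manzana" vs "").
    return apple.equals(plain);
  }
}
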
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
new file mode 100644
index 00000000..840ed5fb
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.label;
+
+import android.util.Log;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/** Label operation utils. */
+public class LabelUtil {
+ /**
+ * Maps an int value tensor to a list of string labels. It takes an array of strings as the
+ * dictionary. Example: if the given tensor is [3, 1, 0], the given labels are ["background",
+ * "apple", "banana", "cherry", "date"], and the offset is 1, the result will be ["date",
+ * "banana", "apple"].
+ *
+ * @param tensorBuffer a tensor with index values. The values should be non-negative integers,
+ * and each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is
+ * given as a float {@link TensorBuffer}, values will be cast to integers. All values that are
+ * out of bounds will map to an empty string.
+ * @param labels a list of strings, used as a dictionary to look up. The index of the array
+ * element will be used as the key. To get better performance, use an object that implements
+ * RandomAccess, such as {@link ArrayList}.
+ * @param offset the offset value used when looking up int values in {@code labels}.
+ * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}.
+ * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null.
+ */
+ public static List<String> mapValueToLabels(
+ @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) {
+ SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null");
+ SupportPreconditions.checkNotNull(labels, "Given labels should not be null");
+ int[] values = tensorBuffer.getIntArray();
+ Log.d("values", Arrays.toString(values));
+ List<String> result = new ArrayList<>();
+ for (int v : values) {
+ int index = v + offset;
+ if (index < 0 || index >= labels.size()) {
+ result.add("");
+ } else {
+ result.add(labels.get(index));
+ }
+ }
+ return result;
+ }
+
+ // Private constructor to prevent initialization.
+ private LabelUtil() {}
+}
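
A sketch mirroring the Javadoc example above: with offset 1, each index x looks up labels[x + 1].

import java.util.Arrays;
import java.util.List;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.label.LabelUtil;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

final class LabelUtilExample {
  static List<String> mapIndices() {
    TensorBuffer indices = TensorBuffer.createFixedSize(new int[] {3}, DataType.UINT8);
    indices.loadArray(new int[] {3, 1, 0}, new int[] {3});
    List<String> labels = Arrays.asList("background", "apple", "banana", "cherry", "date");
    return LabelUtil.mapValueToLabels(indices, labels, 1); // ["date", "banana", "apple"]
  }
}
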
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
new file mode 100644
index 00000000..10763a1a
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java
@@ -0,0 +1,224 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.label;
+
+import android.content.Context;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * TensorLabel is a utility wrapper for TensorBuffers with meaningful labels on an axis.
+ *
+ * <p>For example, an image classification model may have an output tensor with shape {1, 10},
+ * where 1 is the batch size and 10 is the number of categories. In fact, on the 2nd axis, we could
+ * label each sub-tensor with the name or description of each corresponding category. {@link
+ * TensorLabel} can help convert the plain Tensor in {@link TensorBuffer} into a map from
+ * predefined labels to sub-tensors. In this case, if provided with 10 labels for the 2nd axis,
+ * {@link TensorLabel} can convert the original {1, 10} Tensor into a 10-element map, each value
+ * of which is a Tensor in shape {} (scalar). Usage example:
+ *
+ * <pre>
+ * TensorBuffer outputTensor = ...;
+ * {@literal List<String>} labels = FileUtil.loadLabels(context, labelFilePath);
+ * // labels the first axis with size greater than one
+ * TensorLabel labeled = new TensorLabel(labels, outputTensor);
+ * // If each sub-tensor has effectively size 1, we can directly get a float value
+ * {@literal Map<String, Float>} probabilities = labeled.getMapWithFloatValue();
+ * // Or get sub-tensors, when each sub-tensor has elements more than 1
+ * {@literal Map<String, TensorBuffer>} subTensors = labeled.getMapWithTensorBuffer();
+ * </pre>
+ *
+ * <p>Note: currently we only support tensor-to-map conversion for the first axis with size
+ * greater than 1.
+ *
+ * @see org.tensorflow.lite.support.common.FileUtil#loadLabels(Context, String) to load labels
+ * from a label file (a plain text file in which each line is a label) in assets.
+ */
+public class TensorLabel {
+ private final Map<Integer, List<String>> axisLabels;
+ private final TensorBuffer tensorBuffer;
+ private final int[] shape;
+
+ /**
+ * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors.
+ *
+ * @param axisLabels a map whose keys are axis ids (starting from 0) and whose values are the
+ * corresponding labels. Note: the number of labels should match the size of the tensor on
+ * that axis.
+ * @param tensorBuffer the TensorBuffer to be labeled.
+ * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any
+ * value in {@code axisLabels} is null.
+ * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to
+ * the shape of {@code tensorBuffer}), or any value (labels) has a size different from that of
+ * {@code tensorBuffer} on the given dimension.
+ */
+ public TensorLabel(
+ @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) {
+ SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null.");
+ SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null.");
+ this.axisLabels = axisLabels;
+ this.tensorBuffer = tensorBuffer;
+ this.shape = tensorBuffer.getShape();
+ for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
+ int axis = entry.getKey();
+ SupportPreconditions.checkArgument(
+ axis >= 0 && axis < shape.length, "Invalid axis id: " + axis);
+ SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis);
+ SupportPreconditions.checkArgument(
+ shape[axis] == entry.getValue().size(),
+ "Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis);
+ }
+ }
+
+ /**
+ * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors.
+ *
+ * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if
+ * the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from
+ * 0), and the size of {@code axisLabels} should be 10 as well.
+ *
+ * @param axisLabels a list of labels, whose size should match the size of the tensor on the
+ * axis to be labeled.
+ * @param tensorBuffer The TensorBuffer to be labeled.
+ */
+ public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) {
+ this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer);
+ }
+
+ /**
+ * Gets a map that pairs each label with its corresponding TensorBuffer. Currently, the mapping
+ * is only allowed on the first axis with size greater than 1.
+ */
+ @NonNull
+ public Map<String, TensorBuffer> getMapWithTensorBuffer() {
+ int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
+
+ Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>();
+ SupportPreconditions.checkArgument(
+ axisLabels.containsKey(labeledAxis),
+ "Getting a <String, TensorBuffer> map requires the labels to be set on the first non-1 axis.");
+ List<String> labels = axisLabels.get(labeledAxis);
+
+ DataType dataType = tensorBuffer.getDataType();
+ int typeSize = tensorBuffer.getTypeSize();
+ int flatSize = tensorBuffer.getFlatSize();
+
+ // Gets the underlying bytes that could be used to generate the sub-array later.
+ ByteBuffer byteBuffer = tensorBuffer.getBuffer();
+ byteBuffer.rewind();
+
+ // Note: computation below is only correct when labeledAxis is the first axis with size greater
+ // than 1.
+ int subArrayLength = flatSize / shape[labeledAxis] * typeSize;
+ int i = 0;
+ SupportPreconditions.checkNotNull(labels, "Label list should never be null");
+ for (String label : labels) {
+ // Gets the corresponding TensorBuffer.
+ byteBuffer.position(i * subArrayLength);
+ ByteBuffer subBuffer = byteBuffer.slice();
+ // ByteBuffer.slice doesn't preserve the byte order. Set it to match the original buffer's.
+ subBuffer.order(byteBuffer.order()).limit(subArrayLength);
+ TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
+ labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length));
+ labelToTensorMap.put(label, labelBuffer);
+ i += 1;
+ }
+ return labelToTensorMap;
+ }
+
+ /**
+ * Gets a map that maps each label to a float. The mapping is only allowed on the first axis with
+ * size greater than 1, and that axis should effectively be the last axis (which means every sub
+ * tensor specified by this axis should have a flat size of 1).
+ *
+ * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
+ *
+ * @throws IllegalStateException if size of a sub tensor on each label is not 1.
+ */
+ @NonNull
+ public Map<String, Float> getMapWithFloatValue() {
+ int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
+ SupportPreconditions.checkState(
+ labeledAxis == shape.length - 1,
+ "Getting a <String, Scalar> map is only valid when the only labeled axis is the last one.");
+ List<String> labels = axisLabels.get(labeledAxis);
+ float[] data = tensorBuffer.getFloatArray();
+ SupportPreconditions.checkState(labels.size() == data.length);
+ Map<String, Float> result = new LinkedHashMap<>();
+ int i = 0;
+ for (String label : labels) {
+ result.put(label, data[i]);
+ i += 1;
+ }
+ return result;
+ }
+
+ /**
+ * Gets a list of {@link Category} from the {@link TensorLabel} object.
+ *
+ * <p>The axis of labels should effectively be the last axis (which means every sub tensor
+ * specified by this axis should have a flat size of 1), so that each labeled sub tensor can be
+ * converted into a float value score. Example: a {@link TensorLabel} with shape {@code {2, 5, 3}}
+ * and axis 2 is valid. If the axis is 1 or 0, it cannot be converted into a {@link Category}.
+ *
+ * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
+ * the result.
+ *
+ * @throws IllegalStateException if size of a sub tensor on each label is not 1.
+ */
+ @NonNull
+ public List<Category> getCategoryList() {
+ int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
+ SupportPreconditions.checkState(
+ labeledAxis == shape.length - 1,
+ "Getting a Category list is only valid when the only labeled axis is the last one.");
+ List<String> labels = axisLabels.get(labeledAxis);
+ float[] data = tensorBuffer.getFloatArray();
+ SupportPreconditions.checkState(labels.size() == data.length);
+ List<Category> result = new ArrayList<>();
+ int i = 0;
+ for (String label : labels) {
+ result.add(new Category(label, data[i]));
+ i += 1;
+ }
+ return result;
+ }
+
+ private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
+ int[] shape = tensorBuffer.getShape();
+ for (int i = 0; i < shape.length; i++) {
+ if (shape[i] > 1) {
+ return i;
+ }
+ }
+ throw new IllegalArgumentException(
+ "Cannot find an axis to label. A valid axis to label should have size larger than 1.");
+ }
+
+ // Helper function to wrap the List<String> to a one-entry map.
+ private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) {
+ Map<Integer, List<String>> map = new LinkedHashMap<>();
+ map.put(axis, labels);
+ return map;
+ }
+}
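
A sketch of the usage example from the class Javadoc, with a hypothetical {1, 3} output and three labels; axis 1 is both the first axis with size greater than 1 and the last axis, so both map and category conversions are valid:

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

final class TensorLabelExample {
  static List<Category> labelOutput() {
    // A hypothetical {1, 3} classification output.
    TensorBuffer output = TensorBuffer.createFixedSize(new int[] {1, 3}, DataType.FLOAT32);
    output.loadArray(new float[] {0.1f, 0.7f, 0.2f}, new int[] {1, 3});
    List<String> labels = Arrays.asList("cat", "dog", "bird");

    TensorLabel labeled = new TensorLabel(labels, output);
    Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // {cat=0.1, dog=0.7, bird=0.2}
    return labeled.getCategoryList();
  }
}
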
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
new file mode 100644
index 00000000..c2de8c0b
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
@@ -0,0 +1,74 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.label.ops;
+
+import android.content.Context;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.support.common.FileUtil;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+import org.tensorflow.lite.support.label.TensorLabel;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * Labels a TensorBuffer with axisLabels for outputs.
+ *
+ * <p>Apply it on a {@code TensorBuffer} to get a {@code TensorLabel} that can output a Map
+ * pairing each label name with the corresponding TensorBuffer value.
+ */
+public class LabelAxisOp {
+ // Axis and its corresponding label names.
+ private final Map<Integer, List<String>> axisLabels;
+
+ protected LabelAxisOp(Builder builder) {
+ axisLabels = builder.axisLabels;
+ }
+
+ public TensorLabel apply(@NonNull TensorBuffer buffer) {
+ SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
+ return new TensorLabel(axisLabels, buffer);
+ }
+
+ /** The inner builder class to build a LabelAxisOp. */
+ public static class Builder {
+ private final Map<Integer, List<String>> axisLabels;
+
+ protected Builder() {
+ axisLabels = new HashMap<>();
+ }
+
+ public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
+ throws IOException {
+ SupportPreconditions.checkNotNull(context, "Context cannot be null.");
+ SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
+ List<String> labels = FileUtil.loadLabels(context, filePath);
+ axisLabels.put(axis, labels);
+ return this;
+ }
+
+ public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
+ axisLabels.put(axis, labels);
+ return this;
+ }
+
+ public LabelAxisOp build() {
+ return new LabelAxisOp(this);
+ }
+ }
+}
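
A usage sketch (hypothetical asset path "labels.txt", one label per line). Note that the Builder constructor is protected, so as written this example must live in the same package:

package org.tensorflow.lite.support.label.ops; // Builder's constructor is protected

import android.content.Context;
import java.io.IOException;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

final class LabelAxisOpExample {
  // "labels.txt" is a placeholder asset path.
  static TensorLabel labelOutput(Context context, TensorBuffer outputTensor) throws IOException {
    LabelAxisOp op =
        new LabelAxisOp.Builder()
            .addAxisLabel(context, /*axis=*/ 1, "labels.txt")
            .build();
    return op.apply(outputTensor);
  }
}
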
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
new file mode 100644
index 00000000..9cfcf923
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java
@@ -0,0 +1,69 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.model;
+
+import android.util.Log;
+import java.io.Closeable;
+import java.io.IOException;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.tensorflow.lite.Delegate;
+
+/**
+ * Helper class to create and call necessary methods of {@code GpuDelegate}, which is not a strict
+ * dependency.
+ */
+class GpuDelegateProxy implements Delegate, Closeable {
+
+ private static final String TAG = "GpuDelegateProxy";
+
+ private final Delegate proxiedDelegate;
+ private final Closeable proxiedCloseable;
+
+ @Nullable
+ public static GpuDelegateProxy maybeNewInstance() {
+ try {
+ Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate");
+ Object instance = clazz.getDeclaredConstructor().newInstance();
+ return new GpuDelegateProxy(instance);
+ } catch (ReflectiveOperationException e) {
+ Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e);
+ return null;
+ }
+ }
+
+ /** Calls {@code close()} method of the delegate. */
+ @Override
+ public void close() {
+ try {
+ proxiedCloseable.close();
+ } catch (IOException e) {
+ // Should not trigger, because GpuDelegate#close never throws. The catch is required because
+ // of Closeable#close.
+ Log.e(TAG, "Failed to close the GpuDelegate.", e);
+ }
+ }
+
+ /** Calls {@code getNativeHandle()} method of the delegate. */
+ @Override
+ public long getNativeHandle() {
+ return proxiedDelegate.getNativeHandle();
+ }
+
+ private GpuDelegateProxy(Object instance) {
+ this.proxiedCloseable = (Closeable) instance;
+ this.proxiedDelegate = (Delegate) instance;
+ }
+}
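
GpuDelegateProxy is package-private and is reached through Model's GPU option. A sketch with a placeholder asset path; if the optional tensorflow-lite-gpu dependency is absent, maybeNewInstance() returns null, and how Model handles that null is decided in createModel (later in this patch):

import android.content.Context;
import java.io.IOException;
import org.tensorflow.lite.support.model.Model;
import org.tensorflow.lite.support.model.Model.Device;
import org.tensorflow.lite.support.model.Model.Options;

final class GpuModelExample {
  // "model.tflite" is a placeholder asset path.
  static Model loadOnGpu(Context context) throws IOException {
    Options options = new Options.Builder().setDevice(Device.GPU).build();
    return Model.createModel(context, "model.tflite", options);
  }
}
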
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
new file mode 100644
index 00000000..8062d68d
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java
@@ -0,0 +1,285 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.model;
+
+import android.content.Context;
+import java.io.IOException;
+import java.nio.MappedByteBuffer;
+import java.util.Map;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.tensorflow.lite.Interpreter;
+import org.tensorflow.lite.Tensor;
+import org.tensorflow.lite.support.common.FileUtil;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+
+/**
+ * The wrapper class for a TFLite model and a TFLite interpreter.
+ *
+ * <p>Note: A {@link Model} can only hold one TFLite model at a time, and it always holds a TFLite
+ * interpreter instance to run it.
+ */
+public class Model {
+
+ /** The runtime device type used for executing classification. */
+ public enum Device {
+ CPU,
+ NNAPI,
+ GPU
+ }
+
+ /**
+ * Options for running the model. Configurable parameters include:
+ *
+ * <ul>
+ * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the model.
+ * The default value is {@link Device#CPU}.
+ * <li>{@code numThreads} {@link Builder#setNumThreads(int)} specifies the number of threads
+ * used by TFLite inference. It is only effective when the device is set to {@link Device#CPU},
+ * and the default value is 1.
+ * </ul>
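+ *
+ * <p>For example, a sketch of options for GPU execution (this requires the "tensorflow-lite-gpu"
+ * dependency at runtime):
+ *
+ * <pre>{@code
+ * Options options = new Options.Builder().setDevice(Device.GPU).build();
+ * }</pre>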
+ */
+ public static class Options {
+ private final Device device;
+ private final int numThreads;
+
+ /** Builder of {@link Options}. See its doc for details. */
+ public static class Builder {
+ private Device device = Device.CPU;
+ private int numThreads = 1;
+
+ public Builder setDevice(Device device) {
+ this.device = device;
+ return this;
+ }
+
+ public Builder setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ public Options build() {
+ return new Options(this);
+ }
+ }
+
+ private Options(Builder builder) {
+ device = builder.device;
+ numThreads = builder.numThreads;
+ }
+ }
+
+ /** An instance of the driver class to run model inference with TensorFlow Lite. */
+ private final Interpreter interpreter;
+
+ /** Path to tflite model file in asset folder. */
+ private final String modelPath;
+
+ /** The memory-mapped model data. */
+ private final MappedByteBuffer byteModel;
+
+ private final GpuDelegateProxy gpuDelegateProxy;
+
+ /**
+ * Builder for {@link Model}.
+ *
+ * @deprecated Please use {@link Model#createModel(Context, String, Options)}.
+ */
+ @Deprecated
+ public static class Builder {
+ private Device device = Device.CPU;
+ private int numThreads = 1;
+ private final String modelPath;
+ private final MappedByteBuffer byteModel;
+
+ /**
+ * Creates a builder which loads the tflite model from the asset folder using memory-mapped files.
+ *
+ * @param context Application context to access assets.
+ * @param modelPath Asset path of the model (.tflite file).
+ * @throws IOException if an I/O error occurs when loading the tflite model.
+ */
+ @NonNull
+ public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException {
+ this.modelPath = modelPath;
+ byteModel = FileUtil.loadMappedFile(context, modelPath);
+ }
+
+ /** Sets the device to run on. By default, TFLite runs on CPU. */
+ @NonNull
+ public Builder setDevice(Device device) {
+ this.device = device;
+ return this;
+ }
+
+ /** Sets the number of threads. The default is 1. */
+ @NonNull
+ public Builder setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ // Note: The implementation is copied from `Model#createModel`. As the builder is going to be
+ // deprecated, this function is also to be removed.
+ @NonNull
+ public Model build() {
+ Options options = new Options.Builder().setNumThreads(numThreads).setDevice(device).build();
+ return createModel(byteModel, modelPath, options);
+ }
+ }
+
+ /**
+ * Loads a model from assets and initializes the TFLite interpreter.
+ *
+ * <p>The default options are: (1) CPU device; (2) one thread.
+ *
+ * @param context The App Context.
+ * @param modelPath The path of the model file.
+ * @throws IOException if any exception occurs when opening the model file.
+ */
+ public static Model createModel(@NonNull Context context, @NonNull String modelPath)
+ throws IOException {
+ return createModel(context, modelPath, new Options.Builder().build());
+ }
+
+ /**
+ * Loads a model from assets and initializes the TFLite interpreter with the given options.
+ *
+ * @see Options for details.
+ * @param context The App Context.
+ * @param modelPath The path of the model file.
+ * @param options The options for running the model.
+ * @throws IOException if any exception occurs when opening the model file.
+ */
+ public static Model createModel(
+ @NonNull Context context, @NonNull String modelPath, @NonNull Options options)
+ throws IOException {
+ SupportPreconditions.checkNotEmpty(
+ modelPath, "Model path in the asset folder cannot be empty.");
+ MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath);
+ return createModel(byteModel, modelPath, options);
+ }
+
+ /**
+ * Creates a model with a loaded {@link MappedByteBuffer}.
+ *
+ * @see Options for details.
+ * @param byteModel The loaded TFLite model.
+ * @param modelPath The original path of the model. It can be fetched later by {@link
+ * Model#getPath()}.
+ * @param options The options for running the model.
+ * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but
+ * "tensorflow-lite-gpu" is not linked to the project.
+ */
+ public static Model createModel(
+ @NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) {
+ Interpreter.Options interpreterOptions = new Interpreter.Options();
+ GpuDelegateProxy gpuDelegateProxy = null;
+ switch (options.device) {
+ case NNAPI:
+ interpreterOptions.setUseNNAPI(true);
+ break;
+ case GPU:
+ gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance();
+ SupportPreconditions.checkArgument(
+ gpuDelegateProxy != null,
+ "Cannot run inference on GPU. Did you add \"tensorflow-lite-gpu\" as a dependency?");
+ interpreterOptions.addDelegate(gpuDelegateProxy);
+ break;
+ case CPU:
+ break;
+ }
+ interpreterOptions.setNumThreads(options.numThreads);
+ Interpreter interpreter = new Interpreter(byteModel, interpreterOptions);
+ return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy);
+ }
+
+ /** Returns the memory-mapped model data. */
+ @NonNull
+ public MappedByteBuffer getData() {
+ return byteModel;
+ }
+
+ /** Returns the path of the model file stored in Assets. */
+ @NonNull
+ public String getPath() {
+ return modelPath;
+ }
+
+ /**
+ * Gets the Tensor associated with the provided input index.
+ *
+ * @throws IllegalStateException if the interpreter is closed.
+ */
+ public Tensor getInputTensor(int inputIndex) {
+ return interpreter.getInputTensor(inputIndex);
+ }
+
+ /**
+ * Gets the Tensor associated with the provided output index.
+ *
+ * @throws IllegalStateException if the interpreter is closed.
+ */
+ public Tensor getOutputTensor(int outputIndex) {
+ return interpreter.getOutputTensor(outputIndex);
+ }
+
+ /**
+ * Returns the output shape. Useful if the output shape is only determined when the graph is
+ * created.
+ *
+ * @throws IllegalStateException if the interpreter is closed.
+ */
+ public int[] getOutputTensorShape(int outputIndex) {
+ return interpreter.getOutputTensor(outputIndex).shape();
+ }
+
+ /**
+ * Runs model inference on multiple inputs, and returns multiple outputs.
+ *
+ * @param inputs an array of input data. The inputs should be in the same order as inputs of the
+ * model. Each input can be an array or multidimensional array, or a {@link
+ * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link
+ * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types
+ * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer} is
+ * used, its content should remain unchanged until model inference is done.
+ * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
+ * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only
+ * needs to keep entries for the outputs to be used.
+ */
+ public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
+ interpreter.runForMultipleInputsOutputs(inputs, outputs);
+ }
+
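+ /** Closes the interpreter and the GPU delegate proxy (if any), releasing their resources. */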
+ public void close() {
+ if (interpreter != null) {
+ interpreter.close();
+ }
+ if (gpuDelegateProxy != null) {
+ gpuDelegateProxy.close();
+ }
+ }
+
+ private Model(
+ @NonNull String modelPath,
+ @NonNull MappedByteBuffer byteModel,
+ @NonNull Interpreter interpreter,
+ @Nullable GpuDelegateProxy gpuDelegateProxy) {
+ this.modelPath = modelPath;
+ this.byteModel = byteModel;
+ this.interpreter = interpreter;
+ this.gpuDelegateProxy = gpuDelegateProxy;
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
new file mode 100644
index 00000000..446d0ea5
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java
@@ -0,0 +1,430 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.tensorbuffer;
+
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull;
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkState;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Arrays;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.DataType;
+
+/** Represents the data buffer for either a model's input or its output. */
+public abstract class TensorBuffer {
+ /** Where the data is stored. */
+ protected ByteBuffer buffer;
+
+ /** Shape of the tensor stored in this buffer. */
+ protected int[] shape;
+
+ /** Number of elements in the buffer. It will be changed to a proper value in the constructor. */
+ protected int flatSize = -1;
+
+ /**
+ * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers have
+ * pre-allocated memory of a fixed size, while dynamic buffers can be resized.
+ */
+ protected final boolean isDynamic;
+
+ /**
+ * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some
+ * examples:
+ *
+ * <pre>
+ * Creating a float TensorBuffer with shape {2, 3}:
+ * int[] shape = new int[] {2, 3};
+ * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
+ * </pre>
+ *
+ * <pre>
+ * Creating an uint8 TensorBuffer of a scalar:
+ * int[] shape = new int[] {};
+ * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ * </pre>
+ *
+ * <pre>
+ * Creating an empty uint8 TensorBuffer:
+ * int[] shape = new int[] {0};
+ * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8);
+ * </pre>
+ *
+ * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created.
+ *
+ * @param shape The shape of the {@link TensorBuffer} to be created.
+ * @param dataType The dataType of the {@link TensorBuffer} to be created.
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has non-positive elements.
+ */
+ @NonNull
+ public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) {
+ switch (dataType) {
+ case FLOAT32:
+ return new TensorBufferFloat(shape);
+ case UINT8:
+ return new TensorBufferUint8(shape);
+ default:
+ throw new AssertionError("TensorBuffer does not support data type: " + dataType);
+ }
+ }
+
+ /**
+ * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the
+ * created {@link TensorBuffer} is {0}.
+ *
+ * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of
+ * different buffer sizes.
+ *
+ * @param dataType The dataType of the {@link TensorBuffer} to be created.
+ */
+ @NonNull
+ public static TensorBuffer createDynamic(DataType dataType) {
+ switch (dataType) {
+ case FLOAT32:
+ return new TensorBufferFloat();
+ case UINT8:
+ return new TensorBufferUint8();
+ default:
+ throw new AssertionError("TensorBuffer does not support data type: " + dataType);
+ }
+ }
+
+ /**
+ * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}.
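+ *
+ * <p>For example, a sketch converting an existing uint8 buffer ({@code uint8Buffer} stands for
+ * any {@link TensorBufferUint8}) into a float one:
+ *
+ * <pre>{@code
+ * TensorBuffer floatBuffer = TensorBuffer.createFrom(uint8Buffer, DataType.FLOAT32);
+ * }</pre>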
+ *
+ * @param buffer the source {@link TensorBuffer} to copy from.
+ * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}.
+ * @throws NullPointerException if {@code buffer} is null.
+ */
+ @NonNull
+ public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) {
+ checkNotNull(buffer, "Cannot create a buffer from null");
+ TensorBuffer result;
+ if (buffer.isDynamic()) {
+ result = createDynamic(dataType);
+ } else {
+ result = createFixedSize(buffer.shape, dataType);
+ }
+ // A float array is only needed for FLOAT32->FLOAT32 copies; otherwise an int array works as
+ // the intermediate container. This assumption no longer holds once other data types are
+ // supported.
+ if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) {
+ float[] data = buffer.getFloatArray();
+ result.loadArray(data, buffer.shape);
+ } else {
+ int[] data = buffer.getIntArray();
+ result.loadArray(data, buffer.shape);
+ }
+ return result;
+ }
+
+ /** Returns the data buffer. */
+ @NonNull
+ public ByteBuffer getBuffer() {
+ return buffer;
+ }
+
+ /**
+ * Gets the {@link TensorBuffer#flatSize} of the buffer.
+ *
+ * @throws IllegalStateException if the underlying data is corrupted
+ */
+ public int getFlatSize() {
+ assertShapeIsCorrect();
+ return flatSize;
+ }
+
+ /**
+ * Gets the current shape. (A copy is returned here to avoid unexpected modification.)
+ *
+ * @throws IllegalStateException if the underlying data is corrupted
+ */
+ @NonNull
+ public int[] getShape() {
+ assertShapeIsCorrect();
+ return Arrays.copyOf(shape, shape.length);
+ }
+
+ /** Returns the data type of this buffer. */
+ public abstract DataType getDataType();
+
+ /**
+ * Returns a float array of the values stored in this buffer. If the buffer is of a different
+ * type than float, the values will be converted into float. For example, values in {@link
+ * TensorBufferUint8} will be converted from uint8 to float.
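+ *
+ * <p>For example (a sketch; {@code uint8Buffer} stands for any {@link TensorBufferUint8}):
+ *
+ * <pre>{@code
+ * float[] values = uint8Buffer.getFloatArray();  // bytes {0, 255} become {0.0f, 255.0f}
+ * }</pre>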
+ */
+ @NonNull
+ public abstract float[] getFloatArray();
+
+ /**
+ * Returns a float value at a given index. If the buffer is of a different type than float, the
+ * value will be converted into float. For example, when reading a value from {@link
+ * TensorBufferUint8}, the value will be first read out as uint8, and then converted from uint8
+ * to float.
+ *
+ * <pre>
+ * For example, a TensorBuffer with shape {2, 3} that represents the following array,
+ * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
+ *
+ * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
+ * float v = tensorBuffer.getFloatValue(3);
+ * </pre>
+ *
+ * @param absIndex The absolute index of the value to be read.
+ */
+ public abstract float getFloatValue(int absIndex);
+
+ /**
+ * Returns an int array of the values stored in this buffer. If the buffer is of a different
+ * type than int, the values will be converted into int, and loss of precision may apply. For
+ * example, getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}
+ * gives the output {400, 23}.
+ */
+ @NonNull
+ public abstract int[] getIntArray();
+
+ /**
+ * Returns an int value at a given index. If the buffer is of a different type than int, the
+ * value will be converted into int. For example, when reading a value from {@link
+ * TensorBufferFloat}, the value will be first read out as float, and then converted from float
+ * to int. Loss of precision may apply.
+ *
+ * <pre>
+ * For example, a TensorBuffer with shape {2, 3} that represents the following array,
+ * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]].
+ *
+ * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrieved by:
+ * int v = tensorBuffer.getIntValue(3);
+ * Note that v is converted from 3.0f to 3 as a result of type conversion.
+ * </pre>
+ *
+ * @param absIndex The absolute index of the value to be read.
+ */
+ public abstract int getIntValue(int absIndex);
+
+ /**
+ * Returns the number of bytes of a single element in the array. For example, a float buffer will
+ * return 4, and a byte buffer will return 1.
+ */
+ public abstract int getTypeSize();
+
+ /** Returns whether the {@link TensorBuffer} is dynamically sized (can be resized arbitrarily). */
+ public boolean isDynamic() {
+ return isDynamic;
+ }
+
+ /**
+ * Loads an int array into this buffer with a specific shape. If the buffer is of a different
+ * type than int, the values will be converted into the buffer's type before being loaded into
+ * the buffer, and loss of precision may apply. For example, loading an int array with values
+ * {400, -23} into a {@link TensorBufferUint8}, the values will be clamped to [0, 255] and then
+ * cast to uint8, resulting in {255, 0}.
+ *
+ * @param src The source array to be loaded.
+ * @param shape Shape of the tensor that {@code src} represents.
+ * @throws NullPointerException if {@code src} is null.
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if the size of the array to be loaded does not match the
+ * specified shape.
+ */
+ public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape);
+
+ /**
+ * Loads an int array into this buffer. If the buffer is of a different type than int, the
+ * values will be converted into the buffer's type before being loaded into the buffer, and
+ * loss of precision may apply. For example, loading an int array with values {400, -23} into a
+ * {@link TensorBufferUint8}, the values will be clamped to [0, 255] and then cast to uint8,
+ * resulting in {255, 0}.
+ *
+ * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both
+ * fixed-size and dynamic {@link TensorBuffer}.
+ *
+ * @param src The source array to be loaded.
+ */
+ public void loadArray(@NonNull int[] src) {
+ loadArray(src, shape);
+ }
+
+ /**
+ * Loads a float array into this buffer with a specific shape. If the buffer is of a different
+ * type than float, the values will be converted into the buffer's type before being loaded
+ * into the buffer, and loss of precision may apply. For example, loading a float array with
+ * values {400.32f, -23.04f} into a {@link TensorBufferUint8}, the values will be clamped to
+ * [0, 255] and then cast to uint8, resulting in {255, 0}.
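+ *
+ * <p>For example, a sketch loading a 2x3 float tensor:
+ *
+ * <pre>{@code
+ * TensorBuffer buffer = TensorBuffer.createDynamic(DataType.FLOAT32);
+ * buffer.loadArray(new float[] {0f, 1f, 2f, 3f, 4f, 5f}, new int[] {2, 3});
+ * }</pre>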
+ *
+ * @param src The source array to be loaded.
+ * @param shape Shape of the tensor that {@code src} represents.
+ * @throws NullPointerException if {@code src} is null.
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if the size of the array to be loaded does not match the
+ * specified shape.
+ */
+ public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape);
+
+ /**
+ * Loads a float array into this buffer. If the buffer is of a different type than float, the
+ * values will be converted into the buffer's type before being loaded into the buffer, and
+ * loss of precision may apply. For example, loading a float array with values {400.32f,
+ * -23.04f} into a {@link TensorBufferUint8}, the values will be clamped to [0, 255] and then
+ * cast to uint8, resulting in {255, 0}.
+ *
+ * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both
+ * fixed-size and dynamic {@link TensorBuffer}.
+ *
+ * @param src The source array to be loaded.
+ */
+ public void loadArray(@NonNull float[] src) {
+ loadArray(src, shape);
+ }
+
+ /**
+ * Loads a byte buffer into this {@link TensorBuffer} with specific shape.
+ *
+ * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here
+ * for performance reasons, but if modification is necessary, please make a copy.
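+ *
+ * <p>For example, a sketch wrapping four bytes as a {1, 4} uint8 tensor ({@code uint8Buffer}
+ * stands for any {@link TensorBufferUint8}):
+ *
+ * <pre>{@code
+ * ByteBuffer bytes = ByteBuffer.allocateDirect(4).order(ByteOrder.nativeOrder());
+ * // ... fill bytes ...
+ * uint8Buffer.loadBuffer(bytes, new int[] {1, 4});
+ * }</pre>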
+ *
+ * @param buffer The byte buffer to load.
+ * @param shape Shape of the tensor represented by {@code buffer}.
+ * @throws NullPointerException if {@code buffer} is null.
+ * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not
+ * match or the size of {@code buffer} and {@code flatSize} do not match.
+ */
+ public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) {
+ checkNotNull(buffer, "Byte buffer cannot be null.");
+ int flatSize = computeFlatSize(shape);
+ checkArgument(
+ (buffer.limit() == getTypeSize() * flatSize),
+ "The size of byte buffer and the shape do not match.");
+
+ resize(shape);
+ buffer.rewind();
+ this.buffer = buffer;
+ }
+
+ /**
+ * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of
+ * this {@link TensorBuffer}.
+ *
+ * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here
+ * for performance reasons, but if modification is necessary, please make a copy.
+ *
+ * @param buffer The byte buffer to load.
+ */
+ public void loadBuffer(@NonNull ByteBuffer buffer) {
+ loadBuffer(buffer, shape);
+ }
+
+ /**
+ * Constructs a fixed-size {@link TensorBuffer} with specified {@code shape}.
+ *
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has non-positive elements.
+ */
+ protected TensorBuffer(@NonNull int[] shape) {
+ isDynamic = false;
+ allocateMemory(shape);
+ }
+
+ /** Constructs a dynamic {@link TensorBuffer} which can be resized. */
+ protected TensorBuffer() {
+ isDynamic = true;
+ // Initialize the dynamic TensorBuffer with an empty ByteBuffer.
+ allocateMemory(new int[] {0});
+ }
+
+ /** Calculates the number of elements in the buffer. */
+ protected static int computeFlatSize(@NonNull int[] shape) {
+ checkNotNull(shape, "Shape cannot be null.");
+ int prod = 1;
+ for (int s : shape) {
+ prod = prod * s;
+ }
+ return prod;
+ }
+
+ /**
+ * For a dynamic buffer, resizes the memory if needed. For a fixed-size buffer, checks whether
+ * the {@code shape} of {@code src} fits the buffer size.
+ */
+ protected void resize(@NonNull int[] shape) {
+ if (isDynamic) {
+ allocateMemory(shape);
+ } else {
+ // Make sure the new shape fits the buffer size when TensorBuffer has fixed size.
+ checkArgument(Arrays.equals(shape, this.shape));
+ this.shape = shape.clone();
+ }
+ }
+
+ /**
+ * Allocates a buffer with the size corresponding to {@code shape}. If {@code shape} is an
+ * empty array, this {@link TensorBuffer} is created as a scalar and its flatSize will be 1.
+ *
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has negative elements.
+ */
+ private void allocateMemory(@NonNull int[] shape) {
+ checkNotNull(shape, "TensorBuffer shape cannot be null.");
+ checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative.");
+
+ // Check if the new shape is the same as current shape.
+ int newFlatSize = computeFlatSize(shape);
+ this.shape = shape.clone();
+ if (flatSize == newFlatSize) {
+ return;
+ }
+
+ // Update to the new shape.
+ flatSize = newFlatSize;
+ buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize());
+ buffer.order(ByteOrder.nativeOrder());
+ }
+
+ /**
+ * Verifies that the shape of the {@link TensorBuffer} matches the size of the underlying {@link
+ * ByteBuffer}.
+ */
+ private void assertShapeIsCorrect() {
+ int flatSize = computeFlatSize(shape);
+ checkState(
+ (buffer.limit() == getTypeSize() * flatSize),
+ String.format(
+ "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The"
+ + " ByteBuffer may have been changed.",
+ buffer.limit(), Arrays.toString(shape)));
+ }
+
+ /**
+ * Checks whether {@code shape} meets one of the following two requirements: 1. Elements in
+ * {@code shape} are all non-negative numbers. 2. {@code shape} is an empty array, which
+ * corresponds to a scalar.
+ */
+ private static boolean isShapeValid(@NonNull int[] shape) {
+ if (shape.length == 0) {
+ // This shape refers to a scalar.
+ return true;
+ }
+
+ // This shape refers to a multidimensional array.
+ for (int s : shape) {
+ // All elements in shape should be non-negative.
+ if (s < 0) {
+ return false;
+ }
+ }
+ return true;
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
new file mode 100644
index 00000000..65bbd7d0
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java
@@ -0,0 +1,115 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.tensorbuffer;
+
+import java.nio.FloatBuffer;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+
+/** Represents a data buffer with float values. */
+public final class TensorBufferFloat extends TensorBuffer {
+ private static final DataType DATA_TYPE = DataType.FLOAT32;
+
+ /**
+ * Creates a {@link TensorBufferFloat} with specified {@code shape}.
+ *
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has non-positive elements.
+ */
+ TensorBufferFloat(@NonNull int[] shape) {
+ super(shape);
+ }
+
+ TensorBufferFloat() {
+ super();
+ }
+
+ @Override
+ public DataType getDataType() {
+ return DATA_TYPE;
+ }
+
+ @Override
+ @NonNull
+ public float[] getFloatArray() {
+ buffer.rewind();
+ float[] arr = new float[flatSize];
+
+ FloatBuffer floatBuffer = buffer.asFloatBuffer();
+ floatBuffer.get(arr);
+ return arr;
+ }
+
+ @Override
+ public float getFloatValue(int absIndex) {
+ return buffer.getFloat(absIndex << 2);
+ }
+
+ @Override
+ @NonNull
+ public int[] getIntArray() {
+ buffer.rewind();
+ float[] floatArr = new float[flatSize];
+ buffer.asFloatBuffer().get(floatArr);
+
+ int[] intArr = new int[flatSize];
+ for (int i = 0; i < flatSize; i++) {
+ intArr[i] = (int) floatArr[i];
+ }
+ return intArr;
+ }
+
+ @Override
+ public int getIntValue(int absIndex) {
+ return (int) buffer.getFloat(absIndex << 2);
+ }
+
+ @Override
+ public int getTypeSize() {
+ return DATA_TYPE.byteSize();
+ }
+
+ @Override
+ public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
+ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
+ SupportPreconditions.checkArgument(
+ src.length == computeFlatSize(shape),
+ "The size of the array to be loaded does not match the specified shape.");
+ resize(shape);
+ buffer.rewind();
+
+ FloatBuffer floatBuffer = buffer.asFloatBuffer();
+ floatBuffer.put(src);
+ }
+
+ @Override
+ public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
+ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
+ SupportPreconditions.checkArgument(
+ src.length == computeFlatSize(shape),
+ "The size of the array to be loaded does not match the specified shape.");
+ resize(shape);
+ buffer.rewind();
+
+ float[] floatArray = new float[src.length];
+ int cnt = 0;
+ for (int a : src) {
+ floatArray[cnt++] = (float) a;
+ }
+ buffer.asFloatBuffer().put(floatArray);
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
new file mode 100644
index 00000000..33641940
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java
@@ -0,0 +1,121 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.tensorbuffer;
+
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+
+/** Represents a data buffer with 8-bit unsigned integer values. */
+public final class TensorBufferUint8 extends TensorBuffer {
+ private static final DataType DATA_TYPE = DataType.UINT8;
+
+ /**
+ * Creates a {@link TensorBufferUint8} with specified {@code shape}.
+ *
+ * @throws NullPointerException if {@code shape} is null.
+ * @throws IllegalArgumentException if {@code shape} has non-positive elements.
+ */
+ TensorBufferUint8(@NonNull int[] shape) {
+ super(shape);
+ }
+
+ TensorBufferUint8() {
+ super();
+ }
+
+ @Override
+ public DataType getDataType() {
+ return DATA_TYPE;
+ }
+
+ @Override
+ @NonNull
+ public float[] getFloatArray() {
+ buffer.rewind();
+ byte[] byteArr = new byte[flatSize];
+ buffer.get(byteArr);
+
+ float[] floatArr = new float[flatSize];
+ for (int i = 0; i < flatSize; i++) {
+ floatArr[i] = (float) (byteArr[i] & 0xff);
+ }
+ return floatArr;
+ }
+
+ @Override
+ public float getFloatValue(int index) {
+ return (float) (buffer.get(index) & 0xff);
+ }
+
+ @Override
+ @NonNull
+ public int[] getIntArray() {
+ buffer.rewind();
+ byte[] byteArr = new byte[flatSize];
+ buffer.get(byteArr);
+
+ int[] intArr = new int[flatSize];
+ for (int i = 0; i < flatSize; i++) {
+ intArr[i] = byteArr[i] & 0xff;
+ }
+ return intArr;
+ }
+
+ @Override
+ public int getIntValue(int index) {
+ return buffer.get(index) & 0xff;
+ }
+
+ @Override
+ public int getTypeSize() {
+ return DATA_TYPE.byteSize();
+ }
+
+ @Override
+ public void loadArray(@NonNull float[] src, @NonNull int[] shape) {
+ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
+ SupportPreconditions.checkArgument(
+ src.length == computeFlatSize(shape),
+ "The size of the array to be loaded does not match the specified shape.");
+ resize(shape);
+ buffer.rewind();
+
+ byte[] byteArr = new byte[src.length];
+ int cnt = 0;
+ for (float a : src) {
+ byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0);
+ }
+ buffer.put(byteArr);
+ }
+
+ @Override
+ public void loadArray(@NonNull int[] src, @NonNull int[] shape) {
+ SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null.");
+ SupportPreconditions.checkArgument(
+ src.length == computeFlatSize(shape),
+ "The size of the array to be loaded does not match the specified shape.");
+ resize(shape);
+ buffer.rewind();
+
+ byte[] byteArr = new byte[src.length];
+ int cnt = 0;
+ for (int a : src) {
+ byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0);
+ }
+ buffer.put(byteArr);
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD
new file mode 100644
index 00000000..f82b8009
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD
@@ -0,0 +1,22 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+android_library(
+ name = "base-task-api",
+ srcs = glob(["**/*.java"]),
+ javacopts = JAVACOPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ "@com_google_auto_value",
+ ],
+)
+
+alias(
+ name = "base_task_api",
+ actual = ":base-task-api",
+)
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
new file mode 100644
index 00000000..b3fe9def
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java
@@ -0,0 +1,91 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.core;
+
+import android.util.Log;
+import java.io.Closeable;
+
+/**
+ * Base class for the Task API; provides shared logic to load/unload the native libs backing
+ * its C++ counterpart.
+ */
+public abstract class BaseTaskApi implements Closeable {
+ private static final String TAG = BaseTaskApi.class.getSimpleName();
+
+ /**
+ * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is
+ * initialized from subclasses and must be released by calling {@link #deinit} after it is no
+ * longer needed.
+ */
+ private final long nativeHandle;
+
+ /** Indicates whether the {@link #nativeHandle} pointer has been released. */
+ private boolean closed;
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++.
+ */
+ protected BaseTaskApi(long nativeHandle) {
+ if (nativeHandle == TaskJniUtils.INVALID_POINTER) {
+ throw new IllegalArgumentException("Failed to load C++ pointer from JNI");
+ }
+ this.nativeHandle = nativeHandle;
+ }
+
+ public boolean isClosed() {
+ return closed;
+ }
+
+ /** Releases the memory allocated from C++ and deregisters the library from the static holder. */
+ @Override
+ public synchronized void close() {
+ if (closed) {
+ return;
+ }
+ deinit(nativeHandle);
+ closed = true;
+ }
+
+ public long getNativeHandle() {
+ return nativeHandle;
+ }
+
+ protected void checkNotClosed() {
+ if (isClosed()) {
+ throw new IllegalStateException("Internal error: The task lib has already been closed.");
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ if (!closed) {
+ Log.w(TAG, "Native lib was not closed explicitly; closing it in finalize().");
+ close();
+ }
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /**
+ * Releases the memory pointed to by the pointer in the native layer.
+ *
+ * @param nativeHandle pointer to the memory allocated in the native layer
+ */
+ protected abstract void deinit(long nativeHandle);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
new file mode 100644
index 00000000..f5c52a03
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java
@@ -0,0 +1,165 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.core;
+
+import android.content.Context;
+import android.content.res.AssetFileDescriptor;
+import android.util.Log;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+
+/** JNI utils for Task API. */
+public class TaskJniUtils {
+ public static final long INVALID_POINTER = 0;
+ private static final String TAG = TaskJniUtils.class.getSimpleName();
+ /** Syntactic sugar to get a nativeHandle from an empty param list. */
+ public interface EmptyHandleProvider {
+ long createHandle();
+ }
+
+ /** Syntactic sugar to get a nativeHandle from an array of {@link ByteBuffer}s. */
+ public interface MultipleBuffersHandleProvider {
+ long createHandle(ByteBuffer... buffers);
+ }
+
+ /** Syntactic sugar to get a nativeHandle from a file descriptor and options. */
+ public interface FdAndOptionsHandleProvider<T> {
+ long createHandle(
+ int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options);
+ }
+
+ /**
+ * Initializes the JNI and returns a C++ handle with file descriptor and options for the task API.
+ *
+ * @param context the Android app context
+ * @param provider provider to get C++ handle, usually returned from native call
+ * @param libName name of C++ lib to be loaded
+ * @param filePath path of the file to be loaded
+ * @param options options to set up the task API, used by the provider
+ * @return C++ handle as long
+ * @throws IOException If model file fails to load.
+ */
+ public static <T> long createHandleFromFdAndOptions(
+ Context context,
+ final FdAndOptionsHandleProvider<T> provider,
+ String libName,
+ String filePath,
+ final T options)
+ throws IOException {
+ try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) {
+ return createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return provider.createHandle(
+ /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
+ /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
+ options);
+ }
+ },
+ libName);
+ }
+ }
+
+ /**
+ * Initializes the JNI and returns a C++ handle by first loading the C++ library and then
+ * invoking {@link EmptyHandleProvider#createHandle()}.
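+ *
+ * <p>For example (a sketch; {@code initJniWithByteBuffer} stands for some native init method and
+ * {@code "task_text_jni"} for the corresponding native library):
+ *
+ * <pre>{@code
+ * long handle = TaskJniUtils.createHandleFromLibrary(
+ *     new EmptyHandleProvider() {
+ *       public long createHandle() {
+ *         return initJniWithByteBuffer(modelBuffer);  // hypothetical native init method
+ *       }
+ *     },
+ *     "task_text_jni");
+ * }</pre>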
+ *
+ * @param provider provider to get C++ handle, usually returned from native call
+ * @return C++ handle as long
+ */
+ public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) {
+ tryLoadLibrary(libName);
+ try {
+ return provider.createHandle();
+ } catch (Exception e) {
+ String errorMessage = "Error getting native address of native library: " + libName;
+ Log.e(TAG, errorMessage, e);
+ throw new IllegalStateException(errorMessage, e);
+ }
+ }
+
+ /**
+ * Initializes the JNI and returns a C++ handle by first loading the C++ library and then
+ * invoking {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}.
+ *
+ * @param context app context
+ * @param provider provider to get C++ pointer, usually returned from native call
+ * @param libName name of C++ lib to load
+ * @param filePaths file paths to load
+ * @return C++ pointer as long
+ * @throws IOException If model file fails to load.
+ */
+ public static long createHandleWithMultipleAssetFilesFromLibrary(
+ Context context,
+ final MultipleBuffersHandleProvider provider,
+ String libName,
+ String... filePaths)
+ throws IOException {
+ final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length];
+ for (int i = 0; i < filePaths.length; i++) {
+ buffers[i] = loadMappedFile(context, filePaths[i]);
+ }
+ return createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return provider.createHandle(buffers);
+ }
+ },
+ libName);
+ }
+
+ /**
+ * Loads a file from the asset folder through memory mapping.
+ *
+ * @param context Application context to access assets.
+ * @param filePath Asset path of the file.
+ * @return the loaded memory mapped file.
+ * @throws IOException If model file fails to load.
+ */
+ public static MappedByteBuffer loadMappedFile(Context context, String filePath)
+ throws IOException {
+ try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
+ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = fileDescriptor.getStartOffset();
+ long declaredLength = fileDescriptor.getDeclaredLength();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ }
+ }
+
+ private TaskJniUtils() {}
+
+ /**
+ * Tries to load a native library; if it is already loaded, returns directly.
+ *
+ * @param libName name of the lib
+ */
+ static void tryLoadLibrary(String libName) {
+ try {
+ System.loadLibrary(libName);
+ } catch (UnsatisfiedLinkError e) {
+ String errorMessage = "Error loading native library: " + libName;
+ Log.e(TAG, errorMessage, e);
+ throw new UnsatisfiedLinkError(errorMessage);
+ }
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
new file mode 100644
index 00000000..0236f2ce
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java
@@ -0,0 +1,117 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.core.vision;
+
+import android.graphics.Rect;
+import com.google.auto.value.AutoValue;
+
+/**
+ * Options to configure the image processing pipeline, which operates before inference.
+ *
+ * <p>The Task Library Vision API performs image preprocessing on the input image over the region of
+ * interest, so that it fits model requirements (e.g. upright 224x224 RGB) and populates the
+ * corresponding input tensor. This is performed by (in this order):
+ *
+ * <ul>
+ * <li>cropping the frame buffer to the region of interest (which, in most cases, just covers the
+ * entire input image),
+ * <li>resizing it (with bilinear interpolation, aspect-ratio *not* preserved) to the dimensions
+ * of the model input tensor,
+ * <li>converting it to the colorspace of the input tensor (i.e. RGB, which is the only supported
+ * colorspace for now),
+ * <li>rotating it according to its {@link Orientation} so that inference is performed on an
+ * "upright" image.
+ * </ul>
+ *
+ * <p>IMPORTANT: as a consequence of cropping occurring first, the provided region of interest is
+ * expressed in the unrotated frame of reference coordinates system, i.e. in {@code [0,
+ * TensorImage.getWidth()) x [0, TensorImage.getHeight())}, which are the dimensions of the
+ * underlying image data before any orientation gets applied. If the region is out of these bounds,
+ * the inference method, such as {@link ImageClassifier#classify}, will return an error.
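+ *
+ * <p>A usage sketch (the 100x100 region of interest is illustrative):
+ *
+ * <pre>{@code
+ * ImageProcessingOptions options =
+ *     ImageProcessingOptions.builder()
+ *         .setRoi(new Rect(0, 0, 100, 100))
+ *         .setOrientation(ImageProcessingOptions.Orientation.TOP_LEFT)
+ *         .build();
+ * }</pre>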
+ */
+@AutoValue
+public abstract class ImageProcessingOptions {
+
+ /**
+ * Orientation type that follows EXIF specification.
+ *
+ * <p>The name of each enum value defines the position of the 0th row and the 0th column of the
+ * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation
+ * documentation</a> for details.
+ */
+ public enum Orientation {
+ TOP_LEFT(0),
+ TOP_RIGHT(1),
+ BOTTOM_RIGHT(2),
+ BOTTOM_LEFT(3),
+ LEFT_TOP(4),
+ RIGHT_TOP(5),
+ RIGHT_BOTTOM(6),
+ LEFT_BOTTOM(7);
+
+ private final int value;
+
+ Orientation(int value) {
+ this.value = value;
+ }
+
+ public int getValue() {
+ return value;
+ }
+ }
+
+ private static final Rect DEFAULT_ROI = new Rect();
+ private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT;
+
+ public abstract Rect getRoi();
+
+ public abstract Orientation getOrientation();
+
+ public static Builder builder() {
+ return new AutoValue_ImageProcessingOptions.Builder()
+ .setRoi(DEFAULT_ROI)
+ .setOrientation(DEFAULT_ORIENTATION);
+ }
+
+ /** Builder for {@link ImageProcessingOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+
+ /**
+ * Sets the region of interest (ROI) of the image. Defaults to the entire image.
+ *
+ * <p>Cropping according to this region of interest is prepended to the pre-processing
+ * operations.
+ */
+ public abstract Builder setRoi(Rect roi);
+
+ /**
+ * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}.
+ *
+ * <p>Rotation will be applied accordingly so that inference is performed on an "upright" image.
+ */
+ public abstract Builder setOrientation(Orientation orientation);
+
+ abstract Rect getRoi();
+
+ abstract ImageProcessingOptions autoBuild();
+
+ public ImageProcessingOptions build() {
+ setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable.
+ return autoBuild();
+ }
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml
new file mode 100644
index 00000000..d4d1dbad
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="org.tensorflow.lite.task.text">
+ <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/>
+</manifest>
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD
new file mode 100644
index 00000000..695e1bef
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD
@@ -0,0 +1,37 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
+load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(["AndroidManifest.xml"])
+
+android_library(
+ name = "task_library_text",
+ srcs = [
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier:nl_classifier_src",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa:bert_question_answerer_src",
+ ],
+ javacopts = JAVACOPTS,
+ manifest = "AndroidManifest.xml",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_lite_support/java:tensorflowlite_support_java",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
+ "//tensorflow_lite_support/java/src/native/task/text:task_text_native",
+ "@com_google_auto_value",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
+ ],
+)
+
+# AAR target for OSS release.
+#
+# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text:task-library-text
+aar_with_jni(
+ name = "task-library-text",
+ android_library = ":task_library_text",
+)
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD
new file mode 100644
index 00000000..a1d78d8f
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD
@@ -0,0 +1,79 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+filegroup(
+ name = "nl_classifier_src",
+ srcs = glob(["**/*.java"]),
+)
+
+# Java-only target; needs to be used together with a native target similar to
+# third_party/tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_native.
+# Use this target when you want to provide a MutableOpResolver with customized
+# OPs and/or a subset of BuiltInOps to reduce binary size.
+android_library(
+ name = "nl_classifier_java",
+ srcs = [
+ "NLClassifier.java",
+ ],
+ javacopts = JAVACOPTS,
+ deps = [
+ "//tensorflow_lite_support/java:tensorflowlite_support_java",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
+ "@com_google_auto_value",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
+ ],
+)
+
+# Default target that uses BuiltInOpResolver, registers all built-in OPs.
+android_library(
+ name = "nl_classifier",
+ srcs = [
+ "NLClassifier.java",
+ ],
+ javacopts = JAVACOPTS,
+ deps = [
+ "//tensorflow_lite_support/java:tensorflowlite_support_java",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
+ "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_native",
+ "@com_google_auto_value",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
+ ],
+)
+
+# AAR target for OSS release.
+#
+# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier:nl-classifier
+aar_with_jni(
+ name = "nl-classifier",
+ android_library = ":nl_classifier",
+)
+
+# Default target that uses BuiltInOpResolver, registers all built-in OPs.
+android_library(
+ name = "bert_nl_classifier",
+ srcs = [
+ "BertNLClassifier.java",
+ ],
+ javacopts = JAVACOPTS,
+ deps = [
+ "//tensorflow_lite_support/java:tensorflowlite_support_java",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
+ "//tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier:bert_nl_classifier_native",
+ ],
+)
+
+# AAR target for OSS release.
+#
+# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier:bert-nl-classifier
+aar_with_jni(
+ name = "bert-nl-classifier",
+ android_library = ":bert_nl_classifier",
+)
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
new file mode 100644
index 00000000..90bea370
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
@@ -0,0 +1,142 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.text.nlclassifier;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.util.List;
+import org.tensorflow.lite.support.label.Category;
+import org.tensorflow.lite.task.core.BaseTaskApi;
+import org.tensorflow.lite.task.core.TaskJniUtils;
+import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
+
+/**
+ * Classifier API for NLClassification tasks with Bert models; categorizes a string into
+ * different classes. The API expects a Bert-based TFLite model with metadata populated.
+ *
+ * <p>The metadata should contain the following information:
+ *
+ * <ul>
+ * <li>1 input_process_unit for Wordpiece/Sentencepiece Tokenizer.
+ * <li>3 input tensors with names "ids", "mask" and "segment_ids".
+ * <li>1 output tensor of type float32[1, 2], with an optionally attached label file. If a label
+ * file is attached, the file should be a plain text file with one label per line, the number
+ * of labels should match the number of categories the model outputs.
+ * </ul>
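+ *
+ * <p>A usage sketch (the asset name {@code bert_classifier.tflite} is illustrative):
+ *
+ * <pre>{@code
+ * BertNLClassifier classifier = BertNLClassifier.createFromFile(context, "bert_classifier.tflite");
+ * List<Category> results = classifier.classify("This is a positive review.");
+ * classifier.close();
+ * }</pre>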
+ */
+public class BertNLClassifier extends BaseTaskApi {
+ private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++.
+ */
+ private BertNLClassifier(long nativeHandle) {
+ super(nativeHandle);
+ }
+
+ /**
+ * Creates a {@link BertNLClassifier} from a model file with metadata.
+ *
+ * @param context Android context
+ * @param pathToModel Path to the classification model.
+ * @return {@link BertNLClassifier} instance.
+ * @throws IOException If model file fails to load.
+ */
+ public static BertNLClassifier createFromFile(final Context context, final String pathToModel)
+ throws IOException {
+ return createFromBuffer(TaskJniUtils.loadMappedFile(context, pathToModel));
+ }
+
+ /**
+ * Creates a {@link BertNLClassifier} from a {@link File} object with metadata.
+ *
+ * @param modelFile The classification model {@link File} instance.
+ * @return {@link BertNLClassifier} instance.
+ * @throws IOException If model file fails to load.
+ */
+ public static BertNLClassifier createFromFile(File modelFile) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new BertNLClassifier(
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithFileDescriptor(descriptor.getFd());
+ }
+ },
+ BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
+ }
+ }
+
+ /**
+ * Creates a {@link BertNLClassifier} with a model buffer.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model
+ * @return {@link BertNLClassifier} instance
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new BertNLClassifier(
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(modelBuffer);
+ }
+ },
+ BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
+ }
+
+ /**
+ * Performs classification on a string input and returns the classified {@link Category}s.
+ *
+ * @param text input text to the model.
+ * @return A list of Category results.
+ */
+ public List<Category> classify(String text) {
+ return classifyNative(getNativeHandle(), text);
+ }
+
+ private static native long initJniWithByteBuffer(ByteBuffer modelBuffer);
+
+ private static native long initJniWithFileDescriptor(int fd);
+
+ private static native List<Category> classifyNative(long nativeHandle, String text);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
+ }
+
+ /**
+   * Native implementation to release the memory pointed to by the pointer.
+   *
+   * @param nativeHandle pointer to the memory allocated in native code
+ */
+ private native void deinitJni(long nativeHandle);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
new file mode 100644
index 00000000..2bc20d8c
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
@@ -0,0 +1,257 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.text.nlclassifier;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import com.google.auto.value.AutoValue;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.util.List;
+import org.tensorflow.lite.annotations.UsedByReflection;
+import org.tensorflow.lite.support.label.Category;
+import org.tensorflow.lite.task.core.BaseTaskApi;
+import org.tensorflow.lite.task.core.TaskJniUtils;
+import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
+
+/**
+ * Classifier API for natural language classification tasks; it categorizes a string into
+ * different classes.
+ *
+ * <p>The API expects a TFLite model with the following input/output tensors:
+ *
+ * <ul>
+ * <li>Input tensor (kTfLiteString)
+ * <ul>
+ * <li>input of the model, accepts a string.
+ * </ul>
+ * <li>Output score tensor
+ * (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool)
+ * <ul>
+ *       <li>output scores for each class. If the type is one of the Int types, the scores are
+ *           dequantized; if it is the Bool type, the values are converted to 0.0 and 1.0
+ *           respectively.
+ *       <li>can have an optional associated file in metadata for labels. The file should be a
+ *           plain text file with one label per line, and the number of labels should match the
+ *           number of categories the model outputs.
+ * </ul>
+ * <li>Optional Output label tensor (kTfLiteString/kTfLiteInt32)
+ * <ul>
+ *       <li>output classname for each class; should be of the same length as the scores. If this
+ *           tensor is not present, the API uses score indices as classnames.
+ *       <li>will be ignored if the output score tensor already has an associated label file.
+ * </ul>
+ * </ul>
+ *
+ * <p>By default the API tries to find the input/output tensors with the default configurations in
+ * {@link NLClassifierOptions}, with the tensor name prioritized over the tensor index. The options
+ * are configurable for different TFLite models.
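+ *
+ * <p>A minimal usage sketch (the asset name {@code "nl_classifier.tflite"} and the tensor name
+ * are illustrative assumptions for a custom model):
+ *
+ * <pre>{@code
+ * NLClassifierOptions options =
+ *     NLClassifierOptions.builder().setInputTensorName("INPUT").build();
+ * NLClassifier classifier =
+ *     NLClassifier.createFromFileAndOptions(context, "nl_classifier.tflite", options);
+ * List<Category> results = classifier.classify("This is a great movie!");
+ * }</pre>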
+ */
+public class NLClassifier extends BaseTaskApi {
+
+ /** Options to identify input and output tensors of the model. */
+ @AutoValue
+ @UsedByReflection("nl_classifier_jni.cc")
+ public abstract static class NLClassifierOptions {
+ private static final int DEFAULT_INPUT_TENSOR_INDEX = 0;
+ private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0;
+ // By default there is no output label tensor. The label file can be attached
+ // to the output score tensor metadata.
+ private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1;
+ private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT";
+ private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE";
+ private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL";
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract int inputTensorIndex();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract int outputScoreTensorIndex();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract int outputLabelTensorIndex();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract String inputTensorName();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract String outputScoreTensorName();
+
+ @UsedByReflection("nl_classifier_jni.cc")
+ abstract String outputLabelTensorName();
+
+ public static Builder builder() {
+ return new AutoValue_NLClassifier_NLClassifierOptions.Builder()
+ .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX)
+ .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX)
+ .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX)
+ .setInputTensorName(DEFAULT_INPUT_TENSOR_NAME)
+ .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME)
+ .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME);
+ }
+
+ /** Builder for {@link NLClassifierOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ public abstract Builder setInputTensorIndex(int value);
+
+ public abstract Builder setOutputScoreTensorIndex(int value);
+
+ public abstract Builder setOutputLabelTensorIndex(int value);
+
+ public abstract Builder setInputTensorName(String value);
+
+ public abstract Builder setOutputScoreTensorName(String value);
+
+ public abstract Builder setOutputLabelTensorName(String value);
+
+ public abstract NLClassifierOptions build();
+ }
+ }
+
+ private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++.
+ */
+ protected NLClassifier(long nativeHandle) {
+ super(nativeHandle);
+ }
+
+ /**
+   * Create {@link NLClassifier} from a model file in assets, using the default {@link
+   * NLClassifierOptions}.
+ *
+ * @param context Android context.
+ * @param pathToModel Path to the classification model relative to asset dir.
+ * @return {@link NLClassifier} instance.
+ * @throws IOException If model file fails to load.
+ */
+ public static NLClassifier createFromFile(Context context, String pathToModel)
+ throws IOException {
+ return createFromFileAndOptions(context, pathToModel, NLClassifierOptions.builder().build());
+ }
+
+ /**
+   * Create {@link NLClassifier} from a model {@link File}, using the default {@link
+   * NLClassifierOptions}.
+ *
+ * @param modelFile The classification model {@link File} instance.
+ * @return {@link NLClassifier} instance.
+ * @throws IOException If model file fails to load.
+ */
+ public static NLClassifier createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build());
+ }
+
+ /**
+   * Create {@link NLClassifier} from a model file in assets and {@link NLClassifierOptions}.
+ *
+ * @param context Android context
+ * @param pathToModel Path to the classification model relative to asset dir.
+ * @param options Configurations for the model.
+ * @return {@link NLClassifier} instance.
+ * @throws IOException If model file fails to load.
+ */
+ public static NLClassifier createFromFileAndOptions(
+ Context context, String pathToModel, NLClassifierOptions options) throws IOException {
+ return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, pathToModel), options);
+ }
+
+ /**
+   * Create {@link NLClassifier} from a model {@link File} and {@link NLClassifierOptions}.
+ *
+ * @param modelFile The classification model {@link File} instance.
+ * @param options Configurations for the model.
+ * @return {@link NLClassifier} instance.
+ * @throws IOException If model file fails to load.
+ */
+ public static NLClassifier createFromFileAndOptions(
+ File modelFile, final NLClassifierOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new NLClassifier(
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithFileDescriptor(options, descriptor.getFd());
+ }
+ },
+ NL_CLASSIFIER_NATIVE_LIBNAME));
+ }
+ }
+
+ /**
+ * Create {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model
+ * @param options Configurations for the model
+ * @return {@link NLClassifier} instance
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static NLClassifier createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final NLClassifierOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new NLClassifier(
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(options, modelBuffer);
+ }
+ },
+ NL_CLASSIFIER_NATIVE_LIBNAME));
+ }
+
+ /**
+   * Performs classification on a string input and returns the classified {@link Category}s.
+ *
+ * @param text input text to the model.
+ * @return A list of Category results.
+ */
+ public List<Category> classify(String text) {
+ return classifyNative(getNativeHandle(), text);
+ }
+
+ private static native long initJniWithByteBuffer(
+ NLClassifierOptions options, ByteBuffer modelBuffer);
+
+ private static native long initJniWithFileDescriptor(NLClassifierOptions options, int fd);
+
+ private static native List<Category> classifyNative(long nativeHandle, String text);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
+ }
+
+ /**
+   * Native implementation to release the memory pointed to by the pointer.
+   *
+   * @param nativeHandle pointer to the memory allocated in native code
+ */
+ private native void deinitJni(long nativeHandle);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD
new file mode 100644
index 00000000..3dad1422
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD
@@ -0,0 +1,33 @@
+load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+filegroup(
+ name = "bert_question_answerer_src",
+ srcs = glob(["**/*.java"]),
+)
+
+android_library(
+ name = "bert_question_answerer",
+ srcs = glob(["*.java"]),
+ javacopts = JAVACOPTS,
+ deps = [
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
+ "//tensorflow_lite_support/java/src/native/task/text/qa:bert_question_answerer_native",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
+ ],
+)
+
+# AAR target for OSS release.
+#
+# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa:bert-question-answerer
+aar_with_jni(
+ name = "bert-question-answerer",
+ android_library = ":bert_question_answerer",
+)
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
new file mode 100644
index 00000000..76f562ef
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
@@ -0,0 +1,195 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.text.qa;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.List;
+import org.tensorflow.lite.task.core.BaseTaskApi;
+import org.tensorflow.lite.task.core.TaskJniUtils;
+import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
+import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
+
+/**
+ * Task API for BertQA models, answering questions based on the content of a given passage.
+ *
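+ * <p>A minimal usage sketch (the asset name {@code "mobilebert_qa.tflite"} is an illustrative
+ * assumption):
+ *
+ * <pre>{@code
+ * BertQuestionAnswerer answerer =
+ *     BertQuestionAnswerer.createFromFile(context, "mobilebert_qa.tflite");
+ * // Finds the answer spans in the passage for the given question.
+ * List<QaAnswer> answers = answerer.answer(passage, "Where do dandelions grow?");
+ * }</pre>
+ */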
+public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer {
+ private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
+
+ private BertQuestionAnswerer(long nativeHandle) {
+ super(nativeHandle);
+ }
+
+ /**
+   * Generic API to create the QuestionAnswerer for Bert models with metadata populated. The API
+ * expects a Bert based TFLite model with metadata containing the following information:
+ *
+ * <ul>
+ * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be
+ * used for a <a
+ * href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a>
+   *       model, and the Sentencepiece Tokenizer can be used for an <a
+ * href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a>
+ * model.
+ * <li>3 input tensors with names "ids", "mask" and "segment_ids".
+ * <li>2 output tensors with names "end_logits" and "start_logits".
+ * </ul>
+ *
+   * @param context Android context
+ * @param pathToModel file path to the model with metadata. Note: The model should not be
+ * compressed
+ * @return {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load.
+ */
+ public static BertQuestionAnswerer createFromFile(Context context, String pathToModel)
+ throws IOException {
+ return new BertQuestionAnswerer(
+ TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
+ context,
+ new MultipleBuffersHandleProvider() {
+ @Override
+ public long createHandle(ByteBuffer... buffers) {
+ return BertQuestionAnswerer.initJniWithModelWithMetadataByteBuffers(buffers);
+ }
+ },
+ BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
+ pathToModel));
+ }
+
+ /**
+   * Generic API to create the QuestionAnswerer for Bert models with metadata populated. The API
+ * expects a Bert based TFLite model with metadata containing the following information:
+ *
+ * <ul>
+ * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be
+ * used for a <a
+ * href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a>
+   *       model, and the Sentencepiece Tokenizer can be used for an <a
+ * href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a>
+ * model.
+ * <li>3 input tensors with names "ids", "mask" and "segment_ids".
+ * <li>2 output tensors with names "end_logits" and "start_logits".
+ * </ul>
+ *
+ * @param modelFile {@link File} object of the model
+ * @return {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load.
+ */
+  public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new BertQuestionAnswerer(
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithFileDescriptor(descriptor.getFd());
+ }
+ },
+ BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
+ }
+ }
+
+ /**
+   * Creates the API instance with a Bert model and a vocabulary file.
+ *
+ * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
+ *
+   * @param context Android context
+ * @param pathToModel file path to the bert model. Note: The model should not be compressed
+ * @param pathToVocab file path to the vocabulary file. Note: The file should not be compressed
+ * @return {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load.
+ */
+ public static BertQuestionAnswerer createBertQuestionAnswererFromFile(
+ Context context, String pathToModel, String pathToVocab) throws IOException {
+ return new BertQuestionAnswerer(
+ TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
+ context,
+ new MultipleBuffersHandleProvider() {
+ @Override
+ public long createHandle(ByteBuffer... buffers) {
+ return BertQuestionAnswerer.initJniWithBertByteBuffers(buffers);
+ }
+ },
+ BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
+ pathToModel,
+ pathToVocab));
+ }
+
+ /**
+   * Creates the API instance with an Albert model and a sentence piece model file.
+ *
+ * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
+ *
+   * @param context Android context
+ * @param pathToModel file path to the albert model. Note: The model should not be compressed
+ * @param pathToSentencePieceModel file path to the sentence piece model file. Note: The model
+ * should not be compressed
+ * @return {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load.
+ */
+ public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(
+ Context context, String pathToModel, String pathToSentencePieceModel) throws IOException {
+ return new BertQuestionAnswerer(
+ TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
+ context,
+ new MultipleBuffersHandleProvider() {
+ @Override
+ public long createHandle(ByteBuffer... buffers) {
+ return BertQuestionAnswerer.initJniWithAlbertByteBuffers(buffers);
+ }
+ },
+ BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
+ pathToModel,
+ pathToSentencePieceModel));
+ }
+
+ @Override
+ public List<QaAnswer> answer(String context, String question) {
+ checkNotClosed();
+ return answerNative(getNativeHandle(), context, question);
+ }
+
+ // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
+ private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
+
+ // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file
+ // buffer.
+ private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers);
+
+ // modelBuffers[0] is tflite model file buffer with metadata to specify which tokenizer to use.
+ private static native long initJniWithModelWithMetadataByteBuffers(ByteBuffer... modelBuffers);
+
+ private static native long initJniWithFileDescriptor(int fd);
+
+ private static native List<QaAnswer> answerNative(
+ long nativeHandle, String context, String question);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
+ }
+
+ /**
+   * Native implementation to release the memory pointed to by the pointer.
+   *
+   * @param nativeHandle pointer to the memory allocated in native code
+ */
+ private native void deinitJni(long nativeHandle);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
new file mode 100644
index 00000000..4259a697
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java
@@ -0,0 +1,58 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.text.qa;
+
+import org.tensorflow.lite.annotations.UsedByReflection;
+
+/**
+ * Answers to {@link QuestionAnswerer}. Contains information about the answer and its position
+ * within the context.
+ */
+public class QaAnswer {
+ public Pos pos;
+ public String text;
+
+ @UsedByReflection("bert_question_answerer_jni.cc")
+ public QaAnswer(String text, Pos pos) {
+ this.text = text;
+ this.pos = pos;
+ }
+
+ public QaAnswer(String text, int start, int end, float logit) {
+ this(text, new Pos(start, end, logit));
+ }
+
+ /**
+ * Position information of the answer relative to context. It is sortable in descending order
+ * based on logit.
+ */
+ public static class Pos implements Comparable<Pos> {
+ public int start;
+ public int end;
+ public float logit;
+
+ public Pos(int start, int end, float logit) {
+ this.start = start;
+ this.end = end;
+ this.logit = logit;
+ }
+
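+    // Note: the arguments to Float.compare are deliberately reversed (other before this) so that
+    // sorting a collection of Pos orders entries from the highest logit to the lowest.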
+ @Override
+ public int compareTo(Pos other) {
+ return Float.compare(other.logit, this.logit);
+ }
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
new file mode 100644
index 00000000..8df6d379
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java
@@ -0,0 +1,32 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.text.qa;
+
+import java.util.List;
+
+/** API to answer questions based on context. */
+public interface QuestionAnswerer {
+
+ /**
+   * Answers a question based on the given context, and returns a list of possible {@link
+   * QaAnswer}s. The list could be empty if no answer is found in the given context.
+ *
+   * @param context the context that the question is based on
+ * @param question question to ask
+ * @return a list of possible answers in {@link QaAnswer}
+ */
+ List<QaAnswer> answer(String context, String question);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml
new file mode 100644
index 00000000..e77a0734
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="org.tensorflow.lite.task.vision">
+ <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/>
+</manifest>
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD
new file mode 100644
index 00000000..661a7669
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD
@@ -0,0 +1,41 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "AndroidManifest.xml",
+])
+
+android_library(
+ name = "task_library_vision",
+ srcs = [
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier:image_classifier_src",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector:object_detector_src",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter:image_segmenter_src",
+ ],
+    # TODO(b/163039980): Use JAVACOPTS in TF. "-Xep:RemoveUnusedImports:ERROR" weirdly breaks the build.
+ javacopts = ["-source 7 -target 7"],
+ manifest = "AndroidManifest.xml",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow_lite_support/java:tensorflowlite_support_java",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
+ "//tensorflow_lite_support/java/src/native/task/vision:task_vision_native",
+ "@com_google_auto_value",
+ "@maven//:androidx_annotation_annotation",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
+ ],
+)
+
+# AAR target for OSS release.
+#
+# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision:task-library-vision
+aar_with_jni(
+ name = "task-library-vision",
+ android_library = ":task_library_vision",
+)
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml
new file mode 100644
index 00000000..ce07182e
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="org.tensorflow.lite.task.vision.classifier">
+ <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/>
+</manifest>
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD
new file mode 100644
index 00000000..c6a70a08
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD
@@ -0,0 +1,40 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
+load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "AndroidManifest.xml",
+])
+
+filegroup(
+ name = "image_classifier_src",
+ srcs = glob(["**/*.java"]),
+)
+
+android_library(
+ name = "image_classifier",
+ srcs = glob(["*.java"]),
+ javacopts = JAVACOPTS,
+ manifest = "AndroidManifest.xml",
+ deps = [
+ "//tensorflow_lite_support/java:tensorflowlite_support_java",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
+ "//tensorflow_lite_support/java/src/native/task/vision/classifier:image_classifier_native",
+ "@com_google_auto_value",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
+ ],
+)
+
+# AAR target for OSS release.
+#
+# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier:image-classifier
+aar_with_jni(
+ name = "image-classifier",
+ android_library = ":image_classifier",
+)
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
new file mode 100644
index 00000000..d33f0fbb
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java
@@ -0,0 +1,46 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.vision.classifier;
+
+import com.google.auto.value.AutoValue;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.tensorflow.lite.annotations.UsedByReflection;
+import org.tensorflow.lite.support.label.Category;
+
+/**
+ * The classification results of one head in a multihead (a.k.a. multi-output) {@link
+ * ImageClassifier}. A multihead {@link ImageClassifier} can perform classification for multiple
+ * purposes, such as a fine-grained classifier that describes apparel items (e.g. color, material,
+ * type, etc.).
+ */
+@AutoValue
+@UsedByReflection("image_classifier_jni.cc")
+public abstract class Classifications {
+
+ @UsedByReflection("image_classifier_jni.cc")
+ static Classifications create(List<Category> categories, int headIndex) {
+ return new AutoValue_Classifications(
+ Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex);
+ }
+
+ // Same reason for not using ImmutableList as stated in
+ // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}.
+ public abstract List<Category> getCategories();
+
+ public abstract int getHeadIndex();
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
new file mode 100644
index 00000000..46f6754e
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java
@@ -0,0 +1,453 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.vision.classifier;
+
+import android.content.Context;
+import android.graphics.Rect;
+import android.os.ParcelFileDescriptor;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.annotations.UsedByReflection;
+import org.tensorflow.lite.support.image.TensorImage;
+import org.tensorflow.lite.task.core.BaseTaskApi;
+import org.tensorflow.lite.task.core.TaskJniUtils;
+import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
+import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
+import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
+
+/**
+ * Performs classification on images.
+ *
+ * <p>The API expects a TFLite model with optional, but strongly recommended, <a
+ * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
+ *
+ * <p>The API supports models with one image input tensor and one classification output tensor. To
+ * be more specific, here are the requirements.
+ *
+ * <ul>
+ * <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
+ * <ul>
+ * <li>image input of size {@code [batch x height x width x channels]}.
+ * <li>batch inference is not supported ({@code batch} is required to be 1).
+ * <li>only RGB inputs are supported ({@code channels} is required to be 3).
+ * <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
+ * to the metadata for input normalization.
+ * </ul>
+ * <li>Output score tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
+ * <ul>
+ * <li>with {@code N} classes of either 2 or 4 dimensions, such as {@code [1 x N]} or {@code
+ * [1 x 1 x 1 x N]}
+ * <li>the label file is required to be packed to the metadata. See the <a
+ * href="https://www.tensorflow.org/lite/convert/metadata#label_output">example of
+ *           creating metadata for an image classifier</a>. If no label file is packed, the class
+ *           index is used as the label in the results.
+ * </ul>
+ * </ul>
+ *
+ * <p>An example of such a model can be found on <a
+ * href="https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1">TensorFlow
+ * Hub</a>.
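+ *
+ * <p>A minimal usage sketch (the asset name {@code "image_classifier.tflite"} is an illustrative
+ * assumption, and {@code TensorImage.fromBitmap} is assumed from the support library):
+ *
+ * <pre>{@code
+ * ImageClassifierOptions options =
+ *     ImageClassifierOptions.builder().setMaxResults(3).setNumThreads(4).build();
+ * ImageClassifier classifier =
+ *     ImageClassifier.createFromFileAndOptions(context, "image_classifier.tflite", options);
+ * List<Classifications> results = classifier.classify(TensorImage.fromBitmap(bitmap));
+ * }</pre>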
+ */
+public final class ImageClassifier extends BaseTaskApi {
+
+ private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni";
+ private static final int OPTIONAL_FD_LENGTH = -1;
+ private static final int OPTIONAL_FD_OFFSET = -1;
+
+ /**
+ * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
+ *
+ * @param modelPath path of the classification model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native
+ * code
+ */
+ public static ImageClassifier createFromFile(Context context, String modelPath)
+ throws IOException {
+ return createFromFileAndOptions(context, modelPath, ImageClassifierOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}.
+ *
+ * @param modelFile the classification model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native
+ * code
+ */
+ public static ImageClassifier createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link
+ * ImageClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model
+ * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native
+ * code
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) {
+ return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}.
+ *
+ * @param modelPath path of the classification model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native
+ * code
+ */
+ public static ImageClassifier createFromFileAndOptions(
+ Context context, String modelPath, ImageClassifierOptions options) throws IOException {
+ return new ImageClassifier(
+ TaskJniUtils.createHandleFromFdAndOptions(
+ context,
+ new FdAndOptionsHandleProvider<ImageClassifierOptions>() {
+ @Override
+ public long createHandle(
+ int fileDescriptor,
+ long fileDescriptorLength,
+ long fileDescriptorOffset,
+ ImageClassifierOptions options) {
+ return initJniWithModelFdAndOptions(
+ fileDescriptor, fileDescriptorLength, fileDescriptorOffset, options);
+ }
+ },
+ IMAGE_CLASSIFIER_NATIVE_LIB,
+ modelPath,
+ options));
+ }
+
+ /**
+   * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}.
+ *
+ * @param modelFile the classification model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native
+ * code
+ */
+ public static ImageClassifier createFromFileAndOptions(
+ File modelFile, final ImageClassifierOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new ImageClassifier(
+ TaskJniUtils.createHandleFromLibrary(
+ new TaskJniUtils.EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithModelFdAndOptions(
+ descriptor.getFd(),
+ /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
+ /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
+ options);
+ }
+ },
+ IMAGE_CLASSIFIER_NATIVE_LIB));
+ }
+ }
+
+ /**
+ * Creates an {@link ImageClassifier} instance with a model buffer and {@link
+ * ImageClassifierOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model
+ * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native
+ * code
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static ImageClassifier createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final ImageClassifierOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new ImageClassifier(
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(modelBuffer, options);
+ }
+ },
+ IMAGE_CLASSIFIER_NATIVE_LIB));
+ }
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++
+ */
+ private ImageClassifier(long nativeHandle) {
+ super(nativeHandle);
+ }
+
+ /** Options for setting up an ImageClassifier. */
+ @UsedByReflection("image_classifier_jni.cc")
+ public static class ImageClassifierOptions {
+    // Not using AutoValue for this class because scoreThreshold cannot have a default value
+    // (otherwise, the default value would override the one in the model metadata) and `Optional`
+    // is not an option here, because
+    // 1. java.util.Optional requires Java 8 while we need to support Java 7.
+ // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
+ // comments for labelAllowList.
+ private final String displayNamesLocale;
+ private final int maxResults;
+ private final float scoreThreshold;
+ private final boolean isScoreThresholdSet;
+    // As an open source project, we've been trying to avoid depending on common Java libraries,
+    // such as Guava, because it may introduce conflicts with clients who also happen to use those
+    // libraries. Therefore, instead of using ImmutableList here, we convert the List into an
+    // unmodifiableList in setLabelAllowList() and setLabelDenyList() so that the lists cannot be
+    // mutated after the options are built.
+ private final List<String> labelAllowList;
+ private final List<String> labelDenyList;
+ private final int numThreads;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** A builder that helps to configure an instance of ImageClassifierOptions. */
+ public static class Builder {
+ private String displayNamesLocale = "en";
+ private int maxResults = -1;
+ private float scoreThreshold;
+ private boolean isScoreThresholdSet = false;
+ private List<String> labelAllowList = new ArrayList<>();
+ private List<String> labelDenyList = new ArrayList<>();
+ private int numThreads = -1;
+
+ private Builder() {}
+
+ /**
+ * Sets the locale to use for display names specified through the TFLite Model Metadata, if
+ * any.
+ *
+       * <p>Defaults to English ({@code "en"}). See the <a
+       * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
+       * Metadata schema file</a> for the accepted pattern of locale.
+ */
+ public Builder setDisplayNamesLocale(String displayNamesLocale) {
+ this.displayNamesLocale = displayNamesLocale;
+ return this;
+ }
+
+ /**
+ * Sets the maximum number of top scored results to return.
+ *
+       * <p>If &lt; 0, all results will be returned. If 0, an invalid argument error is returned.
+ * Defaults to -1.
+ *
+ * @throws IllegalArgumentException if maxResults is 0.
+ */
+ public Builder setMaxResults(int maxResults) {
+ if (maxResults == 0) {
+ throw new IllegalArgumentException("maxResults cannot be 0.");
+ }
+ this.maxResults = maxResults;
+ return this;
+ }
+
+ /**
+ * Sets the score threshold in [0,1).
+ *
+ * <p>It overrides the one provided in the model metadata (if any). Results below this value
+ * are rejected.
+ */
+ public Builder setScoreThreshold(float scoreThreshold) {
+ this.scoreThreshold = scoreThreshold;
+ isScoreThresholdSet = true;
+ return this;
+ }
+
+ /**
+ * Sets the optional allowlist of labels.
+ *
+ * <p>If non-empty, classifications whose label is not in this set will be filtered out.
+ * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList.
+ */
+ public Builder setLabelAllowList(List<String> labelAllowList) {
+ this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
+ return this;
+ }
+
+ /**
+ * Sets the optional denylist of labels.
+ *
+ * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate
+ * or unknown labels are ignored. Mutually exclusive with labelAllowList.
+ */
+ public Builder setLabelDenyList(List<String> labelDenyList) {
+ this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
+ return this;
+ }
+
+ /**
+ * Sets the number of threads to be used for TFLite ops that support multi-threading when
+ * running inference with CPU. Defaults to -1.
+ *
+ * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
+       * effect of letting the TFLite runtime set the value.
+ */
+ public Builder setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ public ImageClassifierOptions build() {
+ return new ImageClassifierOptions(this);
+ }
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public String getDisplayNamesLocale() {
+ return displayNamesLocale;
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public int getMaxResults() {
+ return maxResults;
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public float getScoreThreshold() {
+ return scoreThreshold;
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public boolean getIsScoreThresholdSet() {
+ return isScoreThresholdSet;
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public List<String> getLabelAllowList() {
+ return new ArrayList<>(labelAllowList);
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public List<String> getLabelDenyList() {
+ return new ArrayList<>(labelDenyList);
+ }
+
+ @UsedByReflection("image_classifier_jni.cc")
+ public int getNumThreads() {
+ return numThreads;
+ }
+
+ private ImageClassifierOptions(Builder builder) {
+ displayNamesLocale = builder.displayNamesLocale;
+ maxResults = builder.maxResults;
+ scoreThreshold = builder.scoreThreshold;
+ isScoreThresholdSet = builder.isScoreThresholdSet;
+ labelAllowList = builder.labelAllowList;
+ labelDenyList = builder.labelDenyList;
+ numThreads = builder.numThreads;
+ }
+ }
+
+ /**
+ * Performs actual classification on the provided image.
+ *
+ * @param image a {@link TensorImage} object that represents an RGB image
+ * @throws AssertionError if error occurs when classifying the image from the native code
+ */
+ public List<Classifications> classify(TensorImage image) {
+ return classify(image, ImageProcessingOptions.builder().build());
+ }
+
+ /**
+ * Performs actual classification on the provided image with {@link ImageProcessingOptions}.
+ *
+ * <p>{@link ImageClassifier} supports the following options:
+ *
+ * <ul>
+ * <li>Region of interest (ROI) (through {@link ImageProcessingOptions#Builder#setRoi}). It
+ * defaults to the entire image.
+ * <li>image rotation (through {@link ImageProcessingOptions#Builder#setOrientation}). It
+ * defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}.
+ * </ul>
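+   *
+   * <p>For example, a sketch that classifies only the top-left quadrant of the image ({@code
+   * image} is assumed to be a valid {@link TensorImage}):
+   *
+   * <pre>{@code
+   * Rect roi = new Rect(0, 0, image.getWidth() / 2, image.getHeight() / 2);
+   * ImageProcessingOptions roiOptions = ImageProcessingOptions.builder().setRoi(roi).build();
+   * List<Classifications> results = classifier.classify(image, roiOptions);
+   * }</pre>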
+ *
+ * @param image a {@link TensorImage} object that represents an RGB image
+ * @throws AssertionError if error occurs when classifying the image from the native code
+ */
+ public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) {
+ checkNotClosed();
+
+    // image_classifier_jni.cc expects a uint8 image. Convert images of other types into uint8.
+ TensorImage imageUint8 =
+ image.getDataType() == DataType.UINT8
+ ? image
+ : TensorImage.createFrom(image, DataType.UINT8);
+
+ Rect roi =
+ options.getRoi().isEmpty()
+ ? new Rect(0, 0, imageUint8.getWidth(), imageUint8.getHeight())
+ : options.getRoi();
+
+ return classifyNative(
+ getNativeHandle(),
+ imageUint8.getBuffer(),
+ imageUint8.getWidth(),
+ imageUint8.getHeight(),
+ new int[] {roi.left, roi.top, roi.width(), roi.height()},
+ options.getOrientation().getValue());
+ }
+
+ private static native long initJniWithModelFdAndOptions(
+ int fileDescriptor,
+ long fileDescriptorLength,
+ long fileDescriptorOffset,
+ ImageClassifierOptions options);
+
+ private static native long initJniWithByteBuffer(
+ ByteBuffer modelBuffer, ImageClassifierOptions options);
+
+ /**
+ * The native method to classify an image with the ROI and orientation.
+ *
+ * @param roi the ROI of the input image, an array representing the bounding box as {left, top,
+ * width, height}
+ * @param orientation the integer value corresponding to {@link
+ * ImageProcessingOptions#Orientation}
+ */
+ private static native List<Classifications> classifyNative(
+ long nativeHandle, ByteBuffer image, int width, int height, int[] roi, int orientation);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
+ }
+
+ /**
+   * Native implementation to release the memory pointed to by the pointer.
+   *
+   * @param nativeHandle pointer to the memory allocated in native code
+ */
+ private native void deinitJni(long nativeHandle);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml
new file mode 100644
index 00000000..5fefccd0
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="org.tensorflow.lite.task.vision.detector">
+ <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/>
+</manifest>
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD
new file mode 100644
index 00000000..d0d541ab
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD
@@ -0,0 +1,40 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
+load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "AndroidManifest.xml",
+])
+
+filegroup(
+ name = "object_detector_src",
+ srcs = glob(["**/*.java"]),
+)
+
+android_library(
+ name = "object_detector",
+ srcs = glob(["*.java"]),
+ javacopts = JAVACOPTS,
+ manifest = "AndroidManifest.xml",
+ deps = [
+ "//tensorflow_lite_support/java:tensorflowlite_support_java",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
+ "//tensorflow_lite_support/java/src/native/task/vision/detector:object_detector_native",
+ "@com_google_auto_value",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
+ ],
+)
+
+# AAR target for OSS release.
+#
+# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector:object-detector
+aar_with_jni(
+ name = "object-detector",
+ android_library = ":object_detector",
+)
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
new file mode 100644
index 00000000..007e032d
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java
@@ -0,0 +1,42 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.vision.detector;
+
+import android.graphics.RectF;
+import com.google.auto.value.AutoValue;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.tensorflow.lite.annotations.UsedByReflection;
+import org.tensorflow.lite.support.label.Category;
+
+/** Represents one detected object in the results of a {@link ObjectDetector}. */
+@AutoValue
+@UsedByReflection("object_detection_jni.cc")
+public abstract class Detection {
+
+ @UsedByReflection("object_detection_jni.cc")
+ public static Detection create(RectF boundingBox, List<Category> categories) {
+ return new AutoValue_Detection(
+ new RectF(boundingBox), Collections.unmodifiableList(new ArrayList<Category>(categories)));
+ }
+
+ public abstract RectF getBoundingBox();
+
+ // Same reason for not using ImmutableList as stated in
+ // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}.
+ public abstract List<Category> getCategories();
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
new file mode 100644
index 00000000..75bc9836
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java
@@ -0,0 +1,452 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.vision.detector;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.annotations.UsedByReflection;
+import org.tensorflow.lite.support.image.TensorImage;
+import org.tensorflow.lite.task.core.BaseTaskApi;
+import org.tensorflow.lite.task.core.TaskJniUtils;
+import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
+import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
+import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
+
+/**
+ * Performs object detection on images.
+ *
+ * <p>The API expects a TFLite model with <a
+ * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
+ *
+ * <p>The API supports models with one image input tensor and four output tensors. To be more
+ * specific, here are the requirements.
+ *
+ * <ul>
+ * <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
+ * <ul>
+ * <li>image input of size {@code [batch x height x width x channels]}.
+ * <li>batch inference is not supported ({@code batch} is required to be 1).
+ * <li>only RGB inputs are supported ({@code channels} is required to be 3).
+ * <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
+ * to the metadata for input normalization.
+ * </ul>
+ * <li>Output tensors must be the 4 outputs of a {@code DetectionPostProcess} op, i.e:
+ * <ul>
+ * <li>Location tensor ({@code kTfLiteFloat32}):
+ * <ul>
+ * <li>tensor of size {@code [1 x num_results x 4]}, the inner array representing
+ * bounding boxes in the form [top, left, right, bottom].
+ * <li>{@code BoundingBoxProperties} are required to be attached to the metadata and
+ * must specify {@code type=BOUNDARIES} and {@code coordinate_type=RATIO}.
+ * </ul>
+ * <li>Classes tensor ({@code kTfLiteFloat32}):
+ * <ul>
+ * <li>tensor of size {@code [1 x num_results]}, each value representing the integer
+ * index of a class.
+ * <li>if label maps are attached to the metadata as {@code TENSOR_VALUE_LABELS}
+ * associated files, they are used to convert the tensor values into labels.
+ * </ul>
+ *     <li>Scores tensor ({@code kTfLiteFloat32}):
+ * <ul>
+ * <li>tensor of size {@code [1 x num_results]}, each value representing the score of
+ * the detected object.
+ * </ul>
+ * <li>Number of detection tensor ({@code kTfLiteFloat32}):
+ * <ul>
+ * <li>integer num_results as a tensor of size {@code [1]}.
+ * </ul>
+ * </ul>
+ * </ul>
+ *
+ * <p>An example of such a model can be found on <a
+ * href="https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1">TensorFlow
+ * Hub</a>.
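+ *
+ * <p>A minimal usage sketch (the asset name {@code "object_detector.tflite"} is an illustrative
+ * assumption, and the {@code detect(TensorImage)} entry point is assumed from the rest of this
+ * class):
+ *
+ * <pre>{@code
+ * ObjectDetector detector = ObjectDetector.createFromFile(context, "object_detector.tflite");
+ * List<Detection> results = detector.detect(TensorImage.fromBitmap(bitmap));
+ * }</pre>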
+ */
+public final class ObjectDetector extends BaseTaskApi {
+
+ private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni";
+ private static final int OPTIONAL_FD_LENGTH = -1;
+ private static final int OPTIONAL_FD_OFFSET = -1;
+
+ /**
+ * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
+ *
+ * @param modelPath path to the detection model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native
+ * code
+ */
+ public static ObjectDetector createFromFile(Context context, String modelPath)
+ throws IOException {
+ return createFromFileAndOptions(context, modelPath, ObjectDetectorOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}.
+ *
+ * @param modelFile the detection model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native
+ * code
+ */
+ public static ObjectDetector createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link
+ * ObjectDetectorOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * detection model
+ * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native
+ * code
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) {
+ return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
+ *
+ * @param modelPath path to the detection model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native
+ * code
+ */
+ public static ObjectDetector createFromFileAndOptions(
+ Context context, String modelPath, ObjectDetectorOptions options) throws IOException {
+ return new ObjectDetector(
+ TaskJniUtils.createHandleFromFdAndOptions(
+ context,
+ new FdAndOptionsHandleProvider<ObjectDetectorOptions>() {
+ @Override
+ public long createHandle(
+ int fileDescriptor,
+ long fileDescriptorLength,
+ long fileDescriptorOffset,
+ ObjectDetectorOptions options) {
+ return initJniWithModelFdAndOptions(
+ fileDescriptor, fileDescriptorLength, fileDescriptorOffset, options);
+ }
+ },
+ OBJECT_DETECTOR_NATIVE_LIB,
+ modelPath,
+ options));
+ }
+
+ /**
+ * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}.
+ *
+ * @param modelFile the detection model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native
+ * code
+ */
+ public static ObjectDetector createFromFileAndOptions(
+ File modelFile, final ObjectDetectorOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new ObjectDetector(
+ TaskJniUtils.createHandleFromLibrary(
+ new TaskJniUtils.EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithModelFdAndOptions(
+ descriptor.getFd(),
+ /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
+ /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
+ options);
+ }
+ },
+ OBJECT_DETECTOR_NATIVE_LIB));
+ }
+ }
+
+ /**
+ * Creates an {@link ObjectDetector} instance with a model buffer and {@link
+ * ObjectDetectorOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * detection model
+ * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native
+ * code
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static ObjectDetector createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final ObjectDetectorOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new ObjectDetector(
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(modelBuffer, options);
+ }
+ },
+ OBJECT_DETECTOR_NATIVE_LIB));
+ }
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++
+ */
+ private ObjectDetector(long nativeHandle) {
+ super(nativeHandle);
+ }
+
+ /** Options for setting up an ObjectDetector. */
+ @UsedByReflection("object_detector_jni.cc")
+ public static class ObjectDetectorOptions {
+ // Not using AutoValue for this class because scoreThreshold cannot have a default value
+ // (otherwise, the default value would override the one in the model metadata) and `Optional` is
+ // not an option here, because
+ // 1. java.util.Optional requires Java 8 while we need to support Java 7.
+ // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
+ // comments for labelAllowList.
+ private final String displayNamesLocale;
+ private final int maxResults;
+ private final float scoreThreshold;
+ private final boolean isScoreThresholdSet;
+ // As an open source project, we've been trying to avoid depending on common Java libraries,
+ // such as Guava, because they may introduce conflicts with clients who also happen to use those
+ // libraries. Therefore, instead of using ImmutableList here, we convert the List into an
+ // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less vulnerable
+ // to accidental modification.
+ private final List<String> labelAllowList;
+ private final List<String> labelDenyList;
+ private final int numThreads;
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** A builder that helps to configure an instance of ObjectDetectorOptions. */
+ public static class Builder {
+ private String displayNamesLocale = "en";
+ private int maxResults = -1;
+ private float scoreThreshold;
+ private boolean isScoreThresholdSet = false;
+ private List<String> labelAllowList = new ArrayList<>();
+ private List<String> labelDenyList = new ArrayList<>();
+ private int numThreads = -1;
+
+ private Builder() {}
+
+ /**
+ * Sets the locale to use for display names specified through the TFLite Model Metadata, if
+ * any.
+ *
+ * <p>Defaults to English ({@code "en"}). See the <a
+ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
+ * Metadata schema file</a> for the accepted pattern of locale.
+ */
+ public Builder setDisplayNamesLocale(String displayNamesLocale) {
+ this.displayNamesLocale = displayNamesLocale;
+ return this;
+ }
+
+ /**
+ * Sets the maximum number of top-scored detection results to return.
+ *
+ * <p>If < 0, all available results will be returned. If 0, an invalid argument error is
+ * returned. Note that models may intrinsically be limited to returning a maximum number of
+ * results N: if the provided value here is above N, only N results will be returned. Defaults
+ * to -1.
+ *
+ * @throws IllegalArgumentException if maxResults is 0.
+ */
+ public Builder setMaxResults(int maxResults) {
+ if (maxResults == 0) {
+ throw new IllegalArgumentException("maxResults cannot be 0.");
+ }
+ this.maxResults = maxResults;
+ return this;
+ }
+
+ /**
+ * Sets the score threshold that overrides the one provided in the model metadata (if any).
+ * Results below this value are rejected.
+ */
+ public Builder setScoreThreshold(float scoreThreshold) {
+ this.scoreThreshold = scoreThreshold;
+ this.isScoreThresholdSet = true;
+ return this;
+ }
+
+ /**
+ * Sets the optional allow list of labels.
+ *
+ * <p>If non-empty, detection results whose label is not in this set will be filtered out.
+ * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelDenyList}: an
+ * {@link AssertionError} is thrown by {@link #createFromFileAndOptions} if both {@code
+ * labelDenyList} and {@code labelAllowList} are set.
+ */
+ public Builder setLabelAllowList(List<String> labelAllowList) {
+ this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList));
+ return this;
+ }
+
+ /**
+ * Sets the optional deny list of labels.
+ *
+ * <p>If non-empty, detection results whose label is in this set will be filtered out. Duplicate
+ * or unknown labels are ignored. Mutually exclusive with {@code labelAllowList}: an
+ * {@link AssertionError} is thrown by {@link #createFromFileAndOptions} if both {@code
+ * labelDenyList} and {@code labelAllowList} are set.
+ */
+ public Builder setLabelDenyList(List<String> labelDenyList) {
+ this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList));
+ return this;
+ }
+
+ /**
+ * Sets the number of threads to be used for TFLite ops that support multi-threading when
+ * running inference with CPU. Defaults to -1.
+ *
+ * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
+ * effect of letting the TFLite runtime set the value.
+ */
+ public Builder setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ public ObjectDetectorOptions build() {
+ return new ObjectDetectorOptions(this);
+ }
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public String getDisplayNamesLocale() {
+ return displayNamesLocale;
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public int getMaxResults() {
+ return maxResults;
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public float getScoreThreshold() {
+ return scoreThreshold;
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public boolean getIsScoreThresholdSet() {
+ return isScoreThresholdSet;
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public List<String> getLabelAllowList() {
+ return new ArrayList<>(labelAllowList);
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public List<String> getLabelDenyList() {
+ return new ArrayList<>(labelDenyList);
+ }
+
+ @UsedByReflection("object_detector_jni.cc")
+ public int getNumThreads() {
+ return numThreads;
+ }
+
+ private ObjectDetectorOptions(Builder builder) {
+ displayNamesLocale = builder.displayNamesLocale;
+ maxResults = builder.maxResults;
+ scoreThreshold = builder.scoreThreshold;
+ isScoreThresholdSet = builder.isScoreThresholdSet;
+ labelAllowList = builder.labelAllowList;
+ labelDenyList = builder.labelDenyList;
+ numThreads = builder.numThreads;
+ }
+ }
+
+ /**
+ * Performs actual detection on the provided image.
+ *
+ * @param image a {@link TensorImage} object that represents an RGB image
+ * @throws AssertionError if error occurs when processing the image from the native code
+ */
+ public List<Detection> detect(TensorImage image) {
+ return detect(image, ImageProcessingOptions.builder().build());
+ }
+
+ /**
+ * Performs actual detection on the provided image.
+ *
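+ * <p>For example, to run detection on an image captured with a rotated camera sensor (the
+ * orientation value below is illustrative, and {@code detector} is assumed to be an initialized
+ * {@link ObjectDetector}):
+ *
+ * <pre>{@code
+ * ImageProcessingOptions options =
+ *     ImageProcessingOptions.builder()
+ *         .setOrientation(ImageProcessingOptions.Orientation.RIGHT_TOP)
+ *         .build();
+ * List<Detection> results = detector.detect(image, options);
+ * }</pre>
+ *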
+ * @param image a {@link TensorImage} object that represents an RGB image
+ * @param options currently, {@link ObjectDetector} only supports image rotation (through {@link
+ * ImageProcessingOptions#Builder#setOrientation}). The orientation of an image defaults to
+ * {@link ImageProcessingOptions#Orientation#TOP_LEFT}.
+ * @throws AssertionError if error occurs when processing the image from the native code
+ */
+ public List<Detection> detect(TensorImage image, ImageProcessingOptions options) {
+ checkNotClosed();
+
+ // object_detector_jni.cc expects a uint8 image. Convert images of other types into uint8.
+ TensorImage imageUint8 =
+ image.getDataType() == DataType.UINT8
+ ? image
+ : TensorImage.createFrom(image, DataType.UINT8);
+ return detectNative(
+ getNativeHandle(),
+ imageUint8.getBuffer(),
+ imageUint8.getWidth(),
+ imageUint8.getHeight(),
+ options.getOrientation().getValue());
+ }
+
+ private static native long initJniWithModelFdAndOptions(
+ int fileDescriptor,
+ long fileDescriptorLength,
+ long fileDescriptorOffset,
+ ObjectDetectorOptions options);
+
+ private static native long initJniWithByteBuffer(
+ ByteBuffer modelBuffer, ObjectDetectorOptions options);
+
+ private static native List<Detection> detectNative(
+ long nativeHandle, ByteBuffer image, int width, int height, int orientation);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
+ }
+
+ /**
+ * Native implementation to release the memory pointed to by the pointer.
+ *
+ * @param nativeHandle pointer to the allocated memory
+ */
+ private native void deinitJni(long nativeHandle);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml
new file mode 100644
index 00000000..991d4816
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml
@@ -0,0 +1,5 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="org.tensorflow.lite.task.vision.segmenter">
+ <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/>
+</manifest>
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD
new file mode 100644
index 00000000..506d5bfb
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD
@@ -0,0 +1,41 @@
+load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni")
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "AndroidManifest.xml",
+])
+
+filegroup(
+ name = "image_segmenter_src",
+ srcs = glob(["**/*.java"]),
+)
+
+android_library(
+ name = "image_segmenter",
+ srcs = glob(["*.java"]),
+ # TODO(b/163039980): Use JAVACOPTS in TF. "-Xep:RemoveUnusedImports:ERROR" weirdly breaks the build.
+ javacopts = ["-source 7 -target 7"],
+ manifest = "AndroidManifest.xml",
+ deps = [
+ "//tensorflow_lite_support/java:tensorflowlite_support_java",
+ "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
+ "//tensorflow_lite_support/java/src/native/task/vision/segmenter:image_segmenter_native",
+ "@com_google_auto_value",
+ "@maven//:androidx_annotation_annotation",
+ "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
+ ],
+)
+
+# AAR target for OSS release.
+#
+# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter:image-segmenter
+aar_with_jni(
+ name = "image-segmenter",
+ android_library = ":image_segmenter",
+)
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
new file mode 100644
index 00000000..09416d08
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java
@@ -0,0 +1,88 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.vision.segmenter;
+
+import android.graphics.Color;
+import android.os.Build;
+import androidx.annotation.RequiresApi;
+import com.google.auto.value.AutoValue;
+import org.tensorflow.lite.annotations.UsedByReflection;
+
+/** Represents a label associated with a color for display purposes. */
+@AutoValue
+@UsedByReflection("image_segmentation_jni.cc")
+public abstract class ColoredLabel {
+
+ /**
+ * Creates a {@link ColoredLabel} object with an ARGB color int.
+ *
+ * @param label the label string, as provided in the label map packed in the TFLite Model
+ * Metadata.
+ * @param displayName the display name of label, as configured through {@link
+ * ImageSegmenter#ImageSegmenterOptions#Builder#setDisplayNamesLocale}
+ * @param argb the color components for the label in ARGB. See <a
+ * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android
+ * Color ints.</a> for more details.
+ */
+ @UsedByReflection("image_segmentation_jni.cc")
+ public static ColoredLabel create(String label, String displayName, int argb) {
+ return new AutoValue_ColoredLabel(label, displayName, argb);
+ }
+
+ /**
+ * Creates a {@link ColoredLabel} object with a {@link Color} instance.
+ *
+ * @param label the label string, as provided in the label map packed in the TFLite Model
+ * Metadata.
+ * @param displayName the display name of label, as configured through {@link
+ * ImageSegmenter#ImageSegmenterOptions#Builder#setDisplayNamesLocale}
+ * @param color the color components for the label. The Color instance is supported on Android
+ * API level 26 and above. For API levels below 26, use {@link #create(String, String,
+ * int)}. See <a
+ * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
+ * Color instances.</a> for more details.
+ */
+ @RequiresApi(Build.VERSION_CODES.O)
+ public static ColoredLabel create(String label, String displayName, Color color) {
+ return new AutoValue_ColoredLabel(label, displayName, color.toArgb());
+ }
+
+ public abstract String getlabel();
+
+ public abstract String getDisplayName();
+
+ /**
+ * Gets the ARGB int that represents the color.
+ *
+ * <p>See <a
+ * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android Color
+ * ints.</a> for more details.
+ */
+ public abstract int getArgb();
+
+ /**
+ * Gets the {@link Color} instance of the underlying color.
+ *
+ * <p>The Color instance is supported on Android API level 26 and above. For API levels below
+ * 26, use {@link #getArgb()}. See <a
+ * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android
+ * Color instances.</a> for more details.
+ */
+ @RequiresApi(Build.VERSION_CODES.O)
+ public Color getColor() {
+ return Color.valueOf(getArgb());
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
new file mode 100644
index 00000000..bd90790f
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java
@@ -0,0 +1,377 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.vision.segmenter;
+
+import android.content.Context;
+import android.content.res.AssetFileDescriptor;
+import android.os.ParcelFileDescriptor;
+import com.google.auto.value.AutoValue;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.tensorflow.lite.DataType;
+import org.tensorflow.lite.support.image.TensorImage;
+import org.tensorflow.lite.task.core.BaseTaskApi;
+import org.tensorflow.lite.task.core.TaskJniUtils;
+import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
+import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
+
+/**
+ * Performs segmentation on images.
+ *
+ * <p>The API expects a TFLite model with <a
+ * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
+ *
+ * <p>The API supports models with one image input tensor and one output tensor. To be more
+ * specific, here are the requirements.
+ *
+ * <ul>
+ * <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
+ * <ul>
+ * <li>image input of size {@code [batch x height x width x channels]}.
+ * <li>batch inference is not supported ({@code batch} is required to be 1).
+ * <li>only RGB inputs are supported ({@code channels} is required to be 3).
+ * <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
+ * to the metadata for input normalization.
+ * </ul>
+ * <li>Output image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
+ * <ul>
+ * <li>tensor of size {@code [batch x mask_height x mask_width x num_classes]}, where {@code
+ * batch} is required to be 1, {@code mask_width} and {@code mask_height} are the
+ * dimensions of the segmentation masks produced by the model, and {@code num_classes}
+ * is the number of classes supported by the model.
+ * <li>optional (but recommended) label map(s) can be attached as AssociatedFile-s with type
+ * TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
+ * any) is used to fill the class name, i.e. {@link ColoredLabel#getlabel}, of the
+ * results. The display name, i.e. {@link ColoredLabel#getDisplayName}, is filled from
+ * the AssociatedFile (if any) whose locale matches the `display_names_locale` field of
+ * the `ImageSegmenterOptions` used at creation time ("en" by default, i.e. English). If
+ * none of these are available, only the `index` field of the results will be filled.
+ * </ul>
+ * </ul>
+ *
+ * <p>An example of such a model can be found on <a
+ * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub</a>.
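+ *
+ * <p>A minimal usage sketch (the asset name {@code "deeplabv3.tflite"} below is a placeholder for
+ * a model you supply; it is not shipped with the library):
+ *
+ * <pre>{@code
+ * ImageSegmenter segmenter =
+ *     ImageSegmenter.createFromFileAndOptions(
+ *         context,
+ *         "deeplabv3.tflite",
+ *         ImageSegmenterOptions.builder().setOutputType(OutputType.CATEGORY_MASK).build());
+ * List<Segmentation> results = segmenter.segment(TensorImage.fromBitmap(bitmap));
+ * segmenter.close();
+ * }</pre>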
+ */
+public final class ImageSegmenter extends BaseTaskApi {
+
+ private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni";
+ private static final int OPTIONAL_FD_LENGTH = -1;
+ private static final int OPTIONAL_FD_OFFSET = -1;
+
+ private final OutputType outputType;
+
+ /**
+ * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
+ *
+ * @param modelPath path of the segmentation model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
+ * code
+ */
+ public static ImageSegmenter createFromFile(Context context, String modelPath)
+ throws IOException {
+ return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}.
+ *
+ * @param modelFile the segmentation model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
+ * code
+ */
+ public static ImageSegmenter createFromFile(File modelFile) throws IOException {
+ return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link
+ * ImageSegmenterOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * segmentation model
+ * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
+ * code
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) {
+ return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build());
+ }
+
+ /**
+ * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
+ *
+ * @param modelPath path of the segmentation model with metadata in the assets
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
+ * code
+ */
+ public static ImageSegmenter createFromFileAndOptions(
+ Context context, String modelPath, final ImageSegmenterOptions options) throws IOException {
+ try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) {
+ return createFromModelFdAndOptions(
+ /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ /*fileDescriptorLength=*/ assetFileDescriptor.getLength(),
+ /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(),
+ options);
+ }
+ }
+
+ /**
+ * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}.
+ *
+ * @param modelFile the segmentation model {@link File} instance
+ * @throws IOException if an I/O error occurs when loading the tflite model
+ * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
+ * code
+ */
+ public static ImageSegmenter createFromFileAndOptions(
+ File modelFile, final ImageSegmenterOptions options) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return createFromModelFdAndOptions(
+ /*fileDescriptor=*/ descriptor.getFd(),
+ /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
+ /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
+ options);
+ }
+ }
+
+ /**
+ * Creates an {@link ImageSegmenter} instance with a model buffer and {@link
+ * ImageSegmenterOptions}.
+ *
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * segmentation model
+ * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native
+ * code
+ * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a
+ * {@link MappedByteBuffer}
+ */
+ public static ImageSegmenter createFromBufferAndOptions(
+ final ByteBuffer modelBuffer, final ImageSegmenterOptions options) {
+ if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) {
+ throw new IllegalArgumentException(
+ "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
+ }
+ return new ImageSegmenter(
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithByteBuffer(
+ modelBuffer,
+ options.getDisplayNamesLocale(),
+ options.getOutputType().getValue(),
+ options.getNumThreads());
+ }
+ },
+ IMAGE_SEGMENTER_NATIVE_LIB),
+ options.getOutputType());
+ }
+
+ /**
+ * Constructor to initialize the JNI with a pointer from C++.
+ *
+ * @param nativeHandle a pointer referencing memory allocated in C++
+ */
+ private ImageSegmenter(long nativeHandle, OutputType outputType) {
+ super(nativeHandle);
+ this.outputType = outputType;
+ }
+
+ /** Options for setting up an {@link ImageSegmenter}. */
+ @AutoValue
+ public abstract static class ImageSegmenterOptions {
+ private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en";
+ private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
+ private static final int NUM_THREADS = -1;
+
+ public abstract String getDisplayNamesLocale();
+
+ public abstract OutputType getOutputType();
+
+ public abstract int getNumThreads();
+
+ public static Builder builder() {
+ return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
+ .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
+ .setOutputType(DEFAULT_OUTPUT_TYPE)
+ .setNumThreads(NUM_THREADS);
+ }
+
+ /** Builder for {@link ImageSegmenterOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+
+ /**
+ * Sets the locale to use for display names specified through the TFLite Model Metadata, if
+ * any.
+ *
+ * <p>Defaults to English ({@code "en"}). See the <a
+ * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite
+ * Metadata schema file</a> for the accepted pattern of locale.
+ */
+ public abstract Builder setDisplayNamesLocale(String displayNamesLocale);
+
+ public abstract Builder setOutputType(OutputType outputType);
+
+ /**
+ * Sets the number of threads to be used for TFLite ops that support multi-threading when
+ * running inference with CPU. Defaults to -1.
+ *
+ * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
+ * effect of letting the TFLite runtime set the value.
+ */
+ public abstract Builder setNumThreads(int numThreads);
+
+ public abstract ImageSegmenterOptions build();
+ }
+ }
+
+ /**
+ * Performs actual segmentation on the provided image.
+ *
+ * @param image a {@link TensorImage} object that represents an RGB image
+ * @return results of performing image segmentation. Note that, at the time of writing, a single
+ * {@link Segmentation} element is expected to be returned. The result is stored in a {@link
+ * List} to allow for later extension to, e.g., instance segmentation models, which may return
+ * one segmentation per object.
+ * @throws AssertionError if error occurs when segmenting the image from the native code
+ */
+ public List<Segmentation> segment(TensorImage image) {
+ return segment(image, ImageProcessingOptions.builder().build());
+ }
+
+ /**
+ * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}.
+ *
+ * @param image a {@link TensorImage} object that represents an RGB image
+ * @param options currently, {@link ImageSegmenter} only supports image rotation (through {@link
+ * ImageProcessingOptions#Builder#setOrientation}). The orientation of an image defaults to
+ * {@link ImageProcessingOptions#Orientation#TOP_LEFT}.
+ * @return results of performing image segmentation. Note that, at the time of writing, a single
+ * {@link Segmentation} element is expected to be returned. The result is stored in a {@link
+ * List} to allow for later extension to, e.g., instance segmentation models, which may return
+ * one segmentation per object.
+ * @throws AssertionError if error occurs when segmenting the image from the native code
+ */
+ public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) {
+ checkNotClosed();
+
+ // image_segmenter_jni.cc expects a uint8 image. Convert images of other types into uint8.
+ TensorImage imageUint8 =
+ image.getDataType() == DataType.UINT8
+ ? image
+ : TensorImage.createFrom(image, DataType.UINT8);
+ List<byte[]> maskByteArrays = new ArrayList<>();
+ List<ColoredLabel> coloredLabels = new ArrayList<>();
+ int[] maskShape = new int[2];
+ segmentNative(
+ getNativeHandle(),
+ imageUint8.getBuffer(),
+ imageUint8.getWidth(),
+ imageUint8.getHeight(),
+ maskByteArrays,
+ maskShape,
+ coloredLabels,
+ options.getOrientation().getValue());
+
+ List<ByteBuffer> maskByteBuffers = new ArrayList<>();
+ for (byte[] bytes : maskByteArrays) {
+ ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
+ // Change the byte order to little-endian, since the buffers were generated in JNI.
+ byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
+ maskByteBuffers.add(byteBuffer);
+ }
+
+ return Arrays.asList(
+ Segmentation.create(
+ outputType,
+ outputType.createMasksFromBuffer(maskByteBuffers, maskShape),
+ coloredLabels));
+ }
+
+ private static ImageSegmenter createFromModelFdAndOptions(
+ final int fileDescriptor,
+ final long fileDescriptorLength,
+ final long fileDescriptorOffset,
+ final ImageSegmenterOptions options) {
+ long nativeHandle =
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithModelFdAndOptions(
+ fileDescriptor,
+ fileDescriptorLength,
+ fileDescriptorOffset,
+ options.getDisplayNamesLocale(),
+ options.getOutputType().getValue(),
+ options.getNumThreads());
+ }
+ },
+ IMAGE_SEGMENTER_NATIVE_LIB);
+ return new ImageSegmenter(nativeHandle, options.getOutputType());
+ }
+
+ private static native long initJniWithModelFdAndOptions(
+ int fileDescriptor,
+ long fileDescriptorLength,
+ long fileDescriptorOffset,
+ String displayNamesLocale,
+ int outputType,
+ int numThreads);
+
+ private static native long initJniWithByteBuffer(
+ ByteBuffer modelBuffer, String displayNamesLocale, int outputType, int numThreads);
+
+ /**
+ * The native method to segment the image.
+ *
+ * <p>{@code maskByteArrays}, {@code maskShape}, and {@code coloredLabels} will be updated in the
+ * native layer.
+ */
+ private static native void segmentNative(
+ long nativeHandle,
+ ByteBuffer image,
+ int width,
+ int height,
+ List<byte[]> maskByteArrays,
+ int[] maskShape,
+ List<ColoredLabel> coloredLabels,
+ int orientation);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
+ }
+
+ /**
+ * Native implementation to release the memory pointed to by the pointer.
+ *
+ * @param nativeHandle pointer to the allocated memory
+ */
+ private native void deinitJni(long nativeHandle);
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
new file mode 100644
index 00000000..03d82c6d
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
@@ -0,0 +1,145 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.vision.segmenter;
+
+import static org.tensorflow.lite.DataType.FLOAT32;
+import static org.tensorflow.lite.DataType.UINT8;
+import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
+import static org.tensorflow.lite.support.image.ColorSpaceType.GRAYSCALE;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import org.tensorflow.lite.support.image.TensorImage;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * Output mask type. This allows specifying the type of post-processing to perform on the raw model
+ * results.
+ */
+public enum OutputType {
+
+ /**
+ * Gives a single output mask where each pixel represents the class to which the corresponding
+ * pixel in the original image was predicted to belong.
+ */
+ CATEGORY_MASK(0) {
+ /**
+ * {@inheritDoc}
+ *
+ * @throws IllegalArgumentException if more than one {@link TensorImage} is provided, or if the
+ * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE}
+ */
+ @Override
+ void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
+ checkArgument(
+ masks.size() == 1,
+ "CATEGORY_MASK only allows one TensorImage in the list, but got " + masks.size());
+
+ TensorImage mask = masks.get(0);
+ checkArgument(
+ mask.getColorSpaceType() == GRAYSCALE,
+ "CATEGORY_MASK only supports masks of ColorSpaceType GRAYSCALE, but got "
+ + mask.getColorSpaceType());
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the list
+ */
+ @Override
+ List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
+ checkArgument(
+ buffers.size() == 1,
+ "CATEGORY_MASK only allows one mask in the buffer list, but got " + buffers.size());
+
+ List<TensorImage> masks = new ArrayList<>();
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8);
+ tensorBuffer.loadBuffer(buffers.get(0), maskShape);
+ TensorImage tensorImage = new TensorImage(UINT8);
+ tensorImage.load(tensorBuffer, GRAYSCALE);
+ masks.add(tensorImage);
+
+ return masks;
+ }
+ },
+
+ /**
+ * Gives a list of output masks where, for each mask, each pixel represents the prediction
+ * confidence, usually in the [0, 1] range.
+ */
+ CONFIDENCE_MASK(1) {
+ /**
+ * {@inheritDoc}
+ *
+ * @throws IllegalArgumentException if the size of the masks list does not match the size of
+ * the coloredLabels list, or if the color space type of any mask is not {@link
+ * ColorSpaceType#GRAYSCALE}
+ */
+ @Override
+ void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
+ checkArgument(
+ masks.size() == coloredLabels.size(),
+ String.format(
+ "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of"
+ + " coloredLabels (%d).",
+ masks.size(), coloredLabels.size()));
+
+ for (TensorImage mask : masks) {
+ checkArgument(
+ mask.getColorSpaceType() == GRAYSCALE,
+ "CONFIDENCE_MASK only supports masks of ColorSpaceType GRAYSCALE, but got "
+ + mask.getColorSpaceType());
+ }
+ }
+
+ @Override
+ List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
+ List<TensorImage> masks = new ArrayList<>();
+ for (ByteBuffer buffer : buffers) {
+ TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32);
+ tensorBuffer.loadBuffer(buffer, maskShape);
+ TensorImage tensorImage = new TensorImage(FLOAT32);
+ tensorImage.load(tensorBuffer, GRAYSCALE);
+ masks.add(tensorImage);
+ }
+ return masks;
+ }
+ };
+
+ public int getValue() {
+ return value;
+ }
+
+ /**
+ * Verifies that the given list of masks matches the list of colored labels.
+ *
+ * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
+ * output type
+ */
+ abstract void assertMasksMatchColoredLabels(
+ List<TensorImage> masks, List<ColoredLabel> coloredLabels);
+
+ /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */
+ abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape);
+
+ private final int value;
+
+ private OutputType(int value) {
+ this.value = value;
+ }
+}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
new file mode 100644
index 00000000..018482c7
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java
@@ -0,0 +1,82 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.vision.segmenter;
+
+import com.google.auto.value.AutoValue;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.tensorflow.lite.support.image.TensorImage;
+
+/** Represents the segmentation result of an {@link ImageSegmenter}. */
+@AutoValue
+public abstract class Segmentation {
+
+ /**
+ * Creates a {@link Segmentation} object.
+ *
+ * <p>{@link Segmentation} provides two types of outputs as indicated through {@link OutputType}:
+ *
+ * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a
+ * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of each
+ * pixel in this mask represents the class to which the corresponding pixel in the original image
+ * belongs. The pixel values are in 1:1 correspondence with the colored labels, i.e. a pixel with
+ * value {@code i} is associated with {@code coloredLabels.get(i)}.
+ *
+ * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which
+ * are in 1:1 correspondence with the colored labels, i.e. {@code masks.get(i)} is associated with
+ * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with
+ * shape (height, width), in row major order. The value of each pixel in these masks represents
+ * the confidence score for this particular class.
+ *
+ * <p>IMPORTANT: segmentation masks are not directly suited for display, in particular:<br>
+ * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code
+ * Orientation} flag of the input FrameBuffer, <br>
+ * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer
+ * dimensions.
+ *
+ * <p>Example of such post-processing, assuming: <br>
+ * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image
+ * will be rotated 90° clockwise during preprocessing to make it "upright"), <br>
+ * \* a model outputting masks of size 224x224. <br>
+ * In order to be directly displayable on top of the input image assumed to be displayed *with*
+ * the {@code Orientation} flag taken into account (according to the <a
+ * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to be:
+ * re-scaled to 640 x 480, then rotated 90° clockwise.
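+ *
+ * <p>A sketch of mapping a category-mask pixel back to its label, assuming {@link
+ * OutputType#CATEGORY_MASK} output ({@code x} and {@code y} are hypothetical pixel coordinates
+ * within the mask):
+ *
+ * <pre>{@code
+ * TensorImage mask = segmentation.getMasks().get(0);
+ * int width = mask.getWidth();
+ * int classIndex = mask.getTensorBuffer().getIntValue(y * width + x); // row-major order
+ * String label = segmentation.getColoredLabels().get(classIndex).getlabel();
+ * }</pre>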
+ *
+ * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
+ * {@code outputType}
+ */
+ static Segmentation create(
+ OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
+ outputType.assertMasksMatchColoredLabels(masks, coloredLabels);
+
+ return new AutoValue_Segmentation(
+ outputType,
+ Collections.unmodifiableList(new ArrayList<TensorImage>(masks)),
+ Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels)));
+ }
+
+ public abstract OutputType getOutputType();
+
+ // As an open source project, we've been trying to avoid depending on common Java libraries,
+ // such as Guava, because they may introduce conflicts with clients who also happen to use those
+ // libraries. Therefore, instead of using ImmutableList here, we convert the List into an
+ // unmodifiableList in create() to make it less vulnerable to accidental modification.
+ public abstract List<TensorImage> getMasks();
+
+ public abstract List<ColoredLabel> getColoredLabels();
+}
diff --git a/tensorflow_lite_support/java/src/native/task/core/BUILD b/tensorflow_lite_support/java/src/native/task/core/BUILD
new file mode 100644
index 00000000..d4dd7ab3
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/core/BUILD
@@ -0,0 +1,16 @@
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+# Default provider for BuiltinOpResolver. Create your own target and override
+# the function to provide a MutableOpResolver for customized ops and/or a
+# subset of builtin ops.
+cc_library(
+ name = "builtin_op_resolver",
+ srcs = ["builtin_op_resolver.cc"],
+ deps = [
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+ ],
+)
diff --git a/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc
new file mode 100644
index 00000000..050f49fc
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/register.h"
+
+namespace tflite {
+namespace task {
+// Default provider for OpResolver, provides BuiltinOpResolver.
+std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT
+ return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
+ new tflite::ops::builtin::BuiltinOpResolver);
+}
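+
+// A customized resolver can instead be provided by your own target. A sketch,
+// assuming a hypothetical custom kernel registered by Register_MY_CUSTOM_OP()
+// (MutableOpResolver comes from "tensorflow/lite/mutable_op_resolver.h"):
+//
+//   std::unique_ptr<OpResolver> CreateOpResolver() {
+//     auto resolver = absl::make_unique<tflite::MutableOpResolver>();
+//     resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
+//                          tflite::ops::builtin::Register_ADD());
+//     resolver->AddCustom("MyCustomOp", Register_MY_CUSTOM_OP());
+//     return resolver;
+//   }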
+
+} // namespace task
+} // namespace tflite
diff --git a/tensorflow_lite_support/java/src/native/task/text/BUILD b/tensorflow_lite_support/java/src/native/task/text/BUILD
new file mode 100644
index 00000000..a27aba52
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/text/BUILD
@@ -0,0 +1,34 @@
+load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "task_text_native",
+ srcs = [
+ ":libtask_text_jni.so",
+ ],
+)
+
+tflite_jni_binary(
+ name = "libtask_text_jni.so",
+ srcs = [
+ "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni.cc",
+ "//tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier:bert_nl_classifier_jni.cc",
+ "//tensorflow_lite_support/java/src/native/task/text/qa:bert_question_answerer_jni.cc",
+ ],
+ linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
+ deps = [
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier",
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier",
+ "//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver",
+ "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni_utils",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+ ],
+)
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD
new file mode 100644
index 00000000..88f3e9be
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD
@@ -0,0 +1,63 @@
+load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "nl_classifier_jni.cc",
+])
+
+# Default native target for nl_classifier to provide BuiltInOpResolver.
+cc_library(
+ name = "nl_classifier_native",
+ srcs = [
+ ":libtask_text_jni.so",
+ ],
+)
+
+# Note: "libtask_text_jni" is hardcoded in Java to look up the .so; therefore
+# the name should remain the same when creating a customized version of
+# nl_classifier_native.
+tflite_jni_binary(
+ name = "libtask_text_jni.so",
+ linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
+ deps = [
+ ":native_without_resolver",
+ "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver",
+ ],
+)
+
+# Shared native logic for NLClassifier. Combine this target with a customized
+# version of op_resolver to build a customized nl_classifier_native target.
+cc_library(
+ name = "native_without_resolver",
+ srcs = [
+ "nl_classifier_jni.cc",
+ ],
+ deps = [
+ ":nl_classifier_jni_utils",
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ "@org_tensorflow//tensorflow/lite:framework",
+ "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "nl_classifier_jni_utils",
+ srcs = [
+ "nl_classifier_jni_utils.cc",
+ ],
+ hdrs = [
+ "nl_classifier_jni_utils.h",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ ],
+)
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD
new file mode 100644
index 00000000..49f3f4e4
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD
@@ -0,0 +1,31 @@
+load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "bert_nl_classifier_jni.cc",
+])
+
+cc_library(
+ name = "bert_nl_classifier_native",
+ srcs = [
+ ":libtask_text_jni.so",
+ ],
+)
+
+tflite_jni_binary(
+ name = "libtask_text_jni.so",
+ srcs = [
+ "bert_nl_classifier_jni.cc",
+ ],
+ linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
+ deps = [
+ "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni_utils",
+ ],
+)
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
new file mode 100644
index 00000000..1edb3507
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
@@ -0,0 +1,74 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+#include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h"
+
+namespace {
+
+using ::tflite::support::utils::GetMappedFileBuffer;
+using ::tflite::support::utils::kAssertionError;
+using ::tflite::support::utils::kInvalidPointer;
+using ::tflite::support::utils::ThrowException;
+using ::tflite::task::text::nlclassifier::BertNLClassifier;
+using ::tflite::task::text::nlclassifier::RunClassifier;
+
+extern "C" JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni(
+ JNIEnv* env, jobject thiz, jlong native_handle) {
+ delete reinterpret_cast<BertNLClassifier*>(native_handle);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer(
+ JNIEnv* env, jclass thiz, jobject model_buffer) {
+ auto model = GetMappedFileBuffer(env, model_buffer);
+ tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status =
+ BertNLClassifier::CreateFromBuffer(model.data(), model.size());
+ if (status.ok()) {
+ return reinterpret_cast<jlong>(status->release());
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when initializing Bert NLClassifier: %s",
+ status.status().message().data());
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor(
+ JNIEnv* env, jclass thiz, jint fd) {
+ tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status =
+ BertNLClassifier::CreateFromFd(fd);
+ if (status.ok()) {
+ return reinterpret_cast<jlong>(status->release());
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when initializing Bert NLClassifier: %s",
+ status.status().message().data());
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT jobject JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative(
+ JNIEnv* env, jclass clazz, jlong native_handle, jstring text) {
+ return RunClassifier(env, native_handle, text);
+}
+
+} // namespace
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
new file mode 100644
index 00000000..d2ace753
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc
@@ -0,0 +1,135 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/op_resolver.h"
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+#include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h"
+
+namespace tflite {
+namespace task {
+// To be provided by a link-time library
+extern std::unique_ptr<OpResolver> CreateOpResolver();
+
+} // namespace task
+} // namespace tflite
+
+namespace {
+
+using ::tflite::support::utils::kAssertionError;
+using ::tflite::support::utils::kInvalidPointer;
+using ::tflite::support::utils::GetMappedFileBuffer;
+using ::tflite::support::utils::JStringToString;
+using ::tflite::support::utils::ThrowException;
+using ::tflite::task::text::nlclassifier::NLClassifier;
+using ::tflite::task::text::nlclassifier::NLClassifierOptions;
+using ::tflite::task::text::nlclassifier::RunClassifier;
+
+
+NLClassifierOptions ConvertJavaNLClassifierOptions(
+ JNIEnv* env, jobject java_nl_classifier_options) {
+ jclass nl_classifier_options_class = env->FindClass(
+ "org/tensorflow/lite/task/text/nlclassifier/"
+ "NLClassifier$NLClassifierOptions");
+ jmethodID input_tensor_index_method_id =
+ env->GetMethodID(nl_classifier_options_class, "inputTensorIndex", "()I");
+ jmethodID output_score_tensor_index_method_id = env->GetMethodID(
+ nl_classifier_options_class, "outputScoreTensorIndex", "()I");
+ jmethodID output_label_tensor_index_method_id = env->GetMethodID(
+ nl_classifier_options_class, "outputLabelTensorIndex", "()I");
+ jmethodID input_tensor_name_method_id = env->GetMethodID(
+ nl_classifier_options_class, "inputTensorName", "()Ljava/lang/String;");
+ jmethodID output_score_tensor_name_method_id =
+ env->GetMethodID(nl_classifier_options_class, "outputScoreTensorName",
+ "()Ljava/lang/String;");
+ jmethodID output_label_tensor_name_method_id =
+ env->GetMethodID(nl_classifier_options_class, "outputLabelTensorName",
+ "()Ljava/lang/String;");
+
+ return {
+ .input_tensor_index = env->CallIntMethod(java_nl_classifier_options,
+ input_tensor_index_method_id),
+ .output_score_tensor_index = env->CallIntMethod(
+ java_nl_classifier_options, output_score_tensor_index_method_id),
+ .output_label_tensor_index = env->CallIntMethod(
+ java_nl_classifier_options, output_label_tensor_index_method_id),
+ .input_tensor_name = JStringToString(
+ env, (jstring)env->CallObjectMethod(java_nl_classifier_options,
+ input_tensor_name_method_id)),
+ .output_score_tensor_name = JStringToString(
+ env,
+ (jstring)env->CallObjectMethod(java_nl_classifier_options,
+ output_score_tensor_name_method_id)),
+ .output_label_tensor_name = JStringToString(
+ env,
+ (jstring)env->CallObjectMethod(java_nl_classifier_options,
+ output_label_tensor_name_method_id)),
+ };
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni(
+ JNIEnv* env, jobject thiz, jlong native_handle) {
+ delete reinterpret_cast<NLClassifier*>(native_handle);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer(
+ JNIEnv* env, jclass thiz, jobject nl_classifier_options,
+ jobject model_buffer) {
+ auto model = GetMappedFileBuffer(env, model_buffer);
+ tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status =
+ NLClassifier::CreateFromBufferAndOptions(
+ model.data(), model.size(),
+ ConvertJavaNLClassifierOptions(env, nl_classifier_options),
+ tflite::task::CreateOpResolver());
+
+ if (status.ok()) {
+ return reinterpret_cast<jlong>(status->release());
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when initializing NLClassifier: %s",
+ status.status().message().data());
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor(
+ JNIEnv* env, jclass thiz, jobject nl_classifier_options, jint fd) {
+ tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status =
+ NLClassifier::CreateFromFdAndOptions(
+ fd, ConvertJavaNLClassifierOptions(env, nl_classifier_options),
+ tflite::task::CreateOpResolver());
+ if (status.ok()) {
+ return reinterpret_cast<jlong>(status->release());
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when initializing NLClassifier: %s",
+ status.status().message().data());
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT jobject JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative(
+ JNIEnv* env, jclass thiz, jlong native_handle, jstring text) {
+ return RunClassifier(env, native_handle, text);
+}
+
+} // namespace
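`ConvertJavaNLClassifierOptions` reads the Java options object purely through reflection, so the `GetMethodID` strings above pin down the accessor names and JNI type signatures exactly. A hedged sketch of the accessor surface they imply — only the method names and types below are guaranteed by this file; the enclosing class shape is an assumption:

```java
// Accessors implied by the GetMethodID calls in ConvertJavaNLClassifierOptions.
public class NLClassifier {
  public abstract static class NLClassifierOptions {
    public abstract int inputTensorIndex();            // "()I"
    public abstract int outputScoreTensorIndex();      // "()I"
    public abstract int outputLabelTensorIndex();      // "()I"
    public abstract String inputTensorName();          // "()Ljava/lang/String;"
    public abstract String outputScoreTensorName();    // "()Ljava/lang/String;"
    public abstract String outputLabelTensorName();    // "()Ljava/lang/String;"
  }
}
```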
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
new file mode 100644
index 00000000..c358bee1
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
@@ -0,0 +1,56 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace nlclassifier {
+
+using ::tflite::support::utils::ConvertVectorToArrayList;
+using ::tflite::support::utils::JStringToString;
+using ::tflite::task::core::Category;
+using ::tflite::task::text::nlclassifier::NLClassifier;
+
+jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) {
+ auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle);
+
+ auto results = nl_classifier->Classify(JStringToString(env, text));
+ jclass category_class =
+ env->FindClass("org/tensorflow/lite/support/label/Category");
+ jmethodID category_init =
+ env->GetMethodID(category_class, "<init>", "(Ljava/lang/String;F)V");
+
+ return ConvertVectorToArrayList<Category>(
+ env, results,
+ [env, category_class, category_init](const Category& category) {
+ jstring class_name = env->NewStringUTF(category.class_name.data());
+ // Convert double to float, as the Java interface exposes scores as float.
+ jobject jcategory =
+ env->NewObject(category_class, category_init, class_name,
+ static_cast<float>(category.score));
+ env->DeleteLocalRef(class_name);
+ return jcategory;
+ });
+}
+
+} // namespace nlclassifier
+} // namespace text
+} // namespace task
+} // namespace tflite
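`RunClassifier` looks up the two-argument `Category` constructor with signature `(Ljava/lang/String;F)V` and boxes each C++ result into it. A minimal matching sketch, assuming field and getter names (only the constructor shape is fixed by the JNI lookup):

```java
package org.tensorflow.lite.support.label;

// Minimal sketch matching the "<init>(Ljava/lang/String;F)V" lookup above;
// field and getter names are assumptions.
public final class Category {
  private final String label;
  private final float score;

  public Category(String label, float score) {
    this.label = label;
    this.score = score;
  }

  public String getLabel() { return label; }

  public float getScore() { return score; }
}
```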
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h
new file mode 100644
index 00000000..2c59ab50
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_JNI_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_JNI_UTILS_H_
+
+#include <jni.h>
+
+namespace tflite {
+namespace task {
+namespace text {
+namespace nlclassifier {
+
+jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text);
+
+} // namespace nlclassifier
+} // namespace text
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_JNI_UTILS_H_
diff --git a/tensorflow_lite_support/java/src/native/task/text/qa/BUILD b/tensorflow_lite_support/java/src/native/task/text/qa/BUILD
new file mode 100644
index 00000000..9753e329
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/text/qa/BUILD
@@ -0,0 +1,30 @@
+load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files([
+ "bert_question_answerer_jni.cc",
+])
+
+tflite_jni_binary(
+ name = "libtask_text_jni.so",
+ srcs = [
+ "bert_question_answerer_jni.cc",
+ ],
+ linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
+ deps = [
+ "//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ ],
+)
+
+cc_library(
+ name = "bert_question_answerer_native",
+ srcs = [
+ ":libtask_text_jni.so",
+ ],
+)
diff --git a/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc b/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
new file mode 100644
index 00000000..92f93467
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc
@@ -0,0 +1,127 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+
+#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+
+namespace {
+
+using ::tflite::support::utils::ConvertVectorToArrayList;
+using ::tflite::support::utils::GetMappedFileBuffer;
+using ::tflite::support::utils::JStringToString;
+using ::tflite::task::text::qa::BertQuestionAnswerer;
+using ::tflite::task::text::qa::QaAnswer;
+using ::tflite::task::text::qa::QuestionAnswerer;
+
+constexpr int kInvalidPointer = 0;
+
+extern "C" JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(
+ JNIEnv* env, jobject thiz, jlong native_handle) {
+ delete reinterpret_cast<QuestionAnswerer*>(native_handle);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithModelWithMetadataByteBuffers(
+ JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
+ absl::string_view model_with_metadata =
+ GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
+
+ tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
+ BertQuestionAnswerer::CreateFromBuffer(
+ model_with_metadata.data(), model_with_metadata.size());
+ if (status.ok()) {
+ return reinterpret_cast<jlong>(status->release());
+ } else {
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(
+ JNIEnv* env, jclass thiz, jint fd) {
+ tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
+ BertQuestionAnswerer::CreateFromFd(fd);
+ if (status.ok()) {
+ return reinterpret_cast<jlong>(status->release());
+ } else {
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
+ JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
+ absl::string_view model =
+ GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
+ absl::string_view vocab =
+ GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1));
+
+ tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
+ BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
+ model.data(), model.size(), vocab.data(), vocab.size());
+ if (status.ok()) {
+ return reinterpret_cast<jlong>(status->release());
+ } else {
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers(
+ JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
+ absl::string_view model =
+ GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
+ absl::string_view sp_model =
+ GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1));
+
+ tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
+ BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
+ model.data(), model.size(), sp_model.data(), sp_model.size());
+ if (status.ok()) {
+ return reinterpret_cast<jlong>(status->release());
+ } else {
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT jobject JNICALL
+Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
+ JNIEnv* env, jclass thiz, jlong native_handle, jstring context,
+ jstring question) {
+ auto* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
+
+ std::vector<QaAnswer> results = question_answerer->Answer(
+ JStringToString(env, context), JStringToString(env, question));
+ jclass qa_answer_class =
+ env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
+ jmethodID qa_answer_ctor =
+ env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
+
+ return ConvertVectorToArrayList<QaAnswer>(
+ env, results,
+ [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
+ jstring text = env->NewStringUTF(ans.text.data());
+ jobject qa_answer =
+ env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
+ ans.pos.end, ans.pos.logit);
+ env->DeleteLocalRef(text);
+ return qa_answer;
+ });
+}
+
+} // namespace
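As with the classifiers, the exported symbols here map one-to-one onto native methods of `org.tensorflow.lite.task.text.qa.BertQuestionAnswerer`, and the `QaAnswer` constructor lookup `(Ljava/lang/String;IIF)V` fixes that type's shape. A sketch under those assumptions — the `ByteBuffer[]` element ordering follows from the `GetObjectArrayElement` indices above, and the `List<QaAnswer>` return type from `ConvertVectorToArrayList`:

```java
package org.tensorflow.lite.task.text.qa;

import java.nio.ByteBuffer;
import java.util.List;

public class BertQuestionAnswerer {
  // modelBuffers[0] is the model with metadata packed in.
  private static native long initJniWithModelWithMetadataByteBuffers(ByteBuffer[] modelBuffers);

  private static native long initJniWithFileDescriptor(int fd);

  // modelBuffers = {model, vocab file} for BERT.
  private static native long initJniWithBertByteBuffers(ByteBuffer[] modelBuffers);

  // modelBuffers = {model, sentencepiece model} for ALBERT.
  private static native long initJniWithAlbertByteBuffers(ByteBuffer[] modelBuffers);

  // Marshalled from std::vector<QaAnswer> via ConvertVectorToArrayList.
  private static native List<QaAnswer> answerNative(long nativeHandle, String context, String question);

  // Takes jobject thiz in the C++ signature, so this one is an instance method.
  private native void deinitJni(long nativeHandle);
}
```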
diff --git a/tensorflow_lite_support/java/src/native/task/vision/BUILD b/tensorflow_lite_support/java/src/native/task/vision/BUILD
new file mode 100644
index 00000000..451a50ca
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/vision/BUILD
@@ -0,0 +1,59 @@
+load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "jni_utils",
+ srcs = [
+ "jni_utils.cc",
+ ],
+ hdrs = [
+ "jni_utils.h",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "task_vision_native",
+ srcs = [
+ ":libtask_vision_jni.so",
+ ],
+)
+
+tflite_jni_binary(
+ name = "libtask_vision_jni.so",
+ srcs = [
+ "//tensorflow_lite_support/java/src/native/task/vision/classifier:image_classifier_jni.cc",
+ "//tensorflow_lite_support/java/src/native/task/vision/detector:object_detector_jni.cc",
+ "//tensorflow_lite_support/java/src/native/task/vision/segmenter:image_segmenter_jni.cc",
+ ],
+ linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/vision:image_classifier",
+ "//tensorflow_lite_support/cc/task/vision:image_segmenter",
+ "//tensorflow_lite_support/cc/task/vision:object_detector",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ "//tensorflow_lite_support/java/src/native/task/vision:jni_utils",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD b/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD
new file mode 100644
index 00000000..8bddc2ac
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD
@@ -0,0 +1,35 @@
+load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(["image_classifier_jni.cc"])
+
+cc_library(
+ name = "image_classifier_native",
+ srcs = [
+ ":libtask_vision_jni.so",
+ ],
+)
+
+tflite_jni_binary(
+ name = "libtask_vision_jni.so",
+ srcs = [
+ "image_classifier_jni.cc",
+ ],
+ linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/vision:image_classifier",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ "//tensorflow_lite_support/java/src/native/task/vision:jni_utils",
+ ],
+)
diff --git a/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
new file mode 100644
index 00000000..2d52f937
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc
@@ -0,0 +1,234 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+
+#include <memory>
+#include <string>
+
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
+
+namespace {
+
+using ::tflite::support::StatusOr;
+using ::tflite::support::utils::GetMappedFileBuffer;
+using ::tflite::support::utils::kAssertionError;
+using ::tflite::support::utils::kInvalidPointer;
+using ::tflite::support::utils::StringListToVector;
+using ::tflite::support::utils::ThrowException;
+using ::tflite::task::vision::BoundingBox;
+using ::tflite::task::vision::ClassificationResult;
+using ::tflite::task::vision::Classifications;
+using ::tflite::task::vision::ConvertToCategory;
+using ::tflite::task::vision::ConvertToFrameBufferOrientation;
+using ::tflite::task::vision::FrameBuffer;
+using ::tflite::task::vision::ImageClassifier;
+using ::tflite::task::vision::ImageClassifierOptions;
+
+// Creates an ImageClassifierOptions proto based on the Java class.
+ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env,
+ jobject java_options) {
+ ImageClassifierOptions proto_options;
+ jclass java_options_class = env->FindClass(
+ "org/tensorflow/lite/task/vision/classifier/"
+ "ImageClassifier$ImageClassifierOptions");
+
+ jmethodID display_names_locale_id = env->GetMethodID(
+ java_options_class, "getDisplayNamesLocale", "()Ljava/lang/String;");
+ jstring display_names_locale = static_cast<jstring>(
+ env->CallObjectMethod(java_options, display_names_locale_id));
+ const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr);
+ proto_options.set_display_names_locale(pchars);
+ env->ReleaseStringUTFChars(display_names_locale, pchars);
+
+ jmethodID max_results_id =
+ env->GetMethodID(java_options_class, "getMaxResults", "()I");
+ jint max_results = env->CallIntMethod(java_options, max_results_id);
+ proto_options.set_max_results(max_results);
+
+ jmethodID is_score_threshold_set_id =
+ env->GetMethodID(java_options_class, "getIsScoreThresholdSet", "()Z");
+ jboolean is_score_threshold_set =
+ env->CallBooleanMethod(java_options, is_score_threshold_set_id);
+ if (is_score_threshold_set) {
+ jmethodID score_threshold_id =
+ env->GetMethodID(java_options_class, "getScoreThreshold", "()F");
+ jfloat score_threshold =
+ env->CallFloatMethod(java_options, score_threshold_id);
+ proto_options.set_score_threshold(score_threshold);
+ }
+
+ jmethodID allow_list_id = env->GetMethodID(
+ java_options_class, "getLabelAllowList", "()Ljava/util/List;");
+ jobject allow_list = env->CallObjectMethod(java_options, allow_list_id);
+ auto allow_list_vector = StringListToVector(env, allow_list);
+ for (const auto& class_name : allow_list_vector) {
+ proto_options.add_class_name_whitelist(class_name);
+ }
+
+ jmethodID deny_list_id = env->GetMethodID(
+ java_options_class, "getLabelDenyList", "()Ljava/util/List;");
+ jobject deny_list = env->CallObjectMethod(java_options, deny_list_id);
+ auto deny_list_vector = StringListToVector(env, deny_list);
+ for (const auto& class_name : deny_list_vector) {
+ proto_options.add_class_name_blacklist(class_name);
+ }
+
+ jmethodID num_threads_id =
+ env->GetMethodID(java_options_class, "getNumThreads", "()I");
+ jint num_threads = env->CallIntMethod(java_options, num_threads_id);
+ proto_options.set_num_threads(num_threads);
+
+ return proto_options;
+}
+
+jobject ConvertToClassificationResults(JNIEnv* env,
+ const ClassificationResult& results) {
+ // jclass and init of Classifications.
+ jclass classifications_class = env->FindClass(
+ "org/tensorflow/lite/task/vision/classifier/Classifications");
+ jmethodID classifications_create =
+ env->GetStaticMethodID(classifications_class, "create",
+ "(Ljava/util/List;I)Lorg/tensorflow/lite/"
+ "task/vision/classifier/Classifications;");
+
+ // jclass, init, and add of ArrayList.
+ jclass array_list_class = env->FindClass("java/util/ArrayList");
+ jmethodID array_list_init =
+ env->GetMethodID(array_list_class, "<init>", "(I)V");
+ jmethodID array_list_add_method =
+ env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z");
+
+ jobject classifications_list =
+ env->NewObject(array_list_class, array_list_init,
+ static_cast<jint>(results.classifications_size()));
+ for (int i = 0; i < results.classifications_size(); i++) {
+ auto classifications = results.classifications(i);
+ jobject jcategory_list = env->NewObject(array_list_class, array_list_init,
+ classifications.classes_size());
+ for (const auto& classification : classifications.classes()) {
+ jobject jcategory = ConvertToCategory(env, classification);
+ env->CallBooleanMethod(jcategory_list, array_list_add_method, jcategory);
+
+ env->DeleteLocalRef(jcategory);
+ }
+ jobject jclassifications = env->CallStaticObjectMethod(
+ classifications_class, classifications_create, jcategory_list,
+ classifications.head_index());
+ env->CallBooleanMethod(classifications_list, array_list_add_method,
+ jclassifications);
+
+ env->DeleteLocalRef(jcategory_list);
+ env->DeleteLocalRef(jclassifications);
+ }
+ return classifications_list;
+}
+
+jlong CreateImageClassifierFromOptions(JNIEnv* env,
+ const ImageClassifierOptions& options) {
+ StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
+ ImageClassifier::CreateFromOptions(options);
+ if (image_classifier_or.ok()) {
+ // Deletion is handled at deinitJni time.
+ return reinterpret_cast<jlong>(image_classifier_or->release());
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when initializing ImageClassifier: %s",
+ image_classifier_or.status().message().data());
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni(
+ JNIEnv* env, jobject thiz, jlong native_handle) {
+ delete reinterpret_cast<ImageClassifier*>(native_handle);
+}
+
+// Creates an ImageClassifier instance from the model file descriptor.
+// file_descriptor_length and file_descriptor_offset are optional. Non-positive
+// values will be ignored.
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModelFdAndOptions(
+ JNIEnv* env, jclass thiz, jint file_descriptor,
+ jlong file_descriptor_length, jlong file_descriptor_offset,
+ jobject java_options) {
+ ImageClassifierOptions proto_options =
+ ConvertToProtoOptions(env, java_options);
+ auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata()
+ ->mutable_file_descriptor_meta();
+ file_descriptor_meta->set_fd(file_descriptor);
+ if (file_descriptor_length > 0) {
+ file_descriptor_meta->set_length(file_descriptor_length);
+ }
+ if (file_descriptor_offset > 0) {
+ file_descriptor_meta->set_offset(file_descriptor_offset);
+ }
+ return CreateImageClassifierFromOptions(env, proto_options);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteBuffer(
+ JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) {
+ ImageClassifierOptions proto_options =
+ ConvertToProtoOptions(env, java_options);
+ // The externally generated proto header does not overload `set_file_content`
+ // with string_view, so GetMappedFileBuffer does not apply here. Creating a
+ // std::string would incur an extra copy of the data, so the most efficient
+ // approach is to set file_content from a char* and its size.
+ proto_options.mutable_model_file_with_metadata()->set_file_content(
+ static_cast<char*>(env->GetDirectBufferAddress(model_buffer)),
+ static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer)));
+ return CreateImageClassifierFromOptions(env, proto_options);
+}
+
+extern "C" JNIEXPORT jobject JNICALL
+Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_classifyNative(
+ JNIEnv* env, jclass thiz, jlong native_handle, jobject image_byte_buffer,
+ jint width, jint height, jintArray jroi, jint jorientation) {
+ auto* classifier = reinterpret_cast<ImageClassifier*>(native_handle);
+ auto image = GetMappedFileBuffer(env, image_byte_buffer);
+ std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
+ reinterpret_cast<const uint8*>(image.data()),
+ FrameBuffer::Dimension{width, height},
+ ConvertToFrameBufferOrientation(env, jorientation));
+
+ int* roi_array = env->GetIntArrayElements(jroi, 0);
+ BoundingBox roi;
+ roi.set_origin_x(roi_array[0]);
+ roi.set_origin_y(roi_array[1]);
+ roi.set_width(roi_array[2]);
+ roi.set_height(roi_array[3]);
+ env->ReleaseIntArrayElements(jroi, roi_array, 0);
+
+ auto results_or = classifier->Classify(*frame_buffer, roi);
+ if (results_or.ok()) {
+ return ConvertToClassificationResults(env, results_or.value());
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when classifying the image: %s",
+ results_or.status().message().data());
+ return nullptr;
+ }
+}
+} // namespace
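`classifyNative` receives the image as a direct `ByteBuffer` plus its dimensions, a four-element ROI array, and an orientation code. A hedged sketch of the Java-side declaration it binds to — the ROI packing `{originX, originY, width, height}` is fixed by the `roi_array[0..3]` reads above, while parameter names and the return type (inferred from `ConvertToClassificationResults`, which builds an `ArrayList` of `Classifications`) are assumptions:

```java
package org.tensorflow.lite.task.vision.classifier;

import java.nio.ByteBuffer;
import java.util.List;

public class ImageClassifier {
  // Hypothetical declaration matching the JNI signature above.
  private static native List<Classifications> classifyNative(
      long nativeHandle,
      ByteBuffer image,   // raw RGB pixels
      int width,
      int height,
      int[] roi,          // {originX, originY, width, height}
      int orientation);   // 0-7, see ConvertToFrameBufferOrientation
}
```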
diff --git a/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD b/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD
new file mode 100644
index 00000000..5abd3f1a
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD
@@ -0,0 +1,36 @@
+load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(["object_detector_jni.cc"])
+
+cc_library(
+ name = "object_detector_native",
+ srcs = [
+ ":libtask_vision_jni.so",
+ ],
+)
+
+tflite_jni_binary(
+ name = "libtask_vision_jni.so",
+ srcs = [
+ "object_detector_jni.cc",
+ ],
+ linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/vision:object_detector",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ "//tensorflow_lite_support/java/src/native/task/vision:jni_utils",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc b/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
new file mode 100644
index 00000000..016b0bfd
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc
@@ -0,0 +1,228 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/object_detector.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
+
+namespace {
+
+using ::tflite::support::StatusOr;
+using ::tflite::support::utils::GetMappedFileBuffer;
+using ::tflite::support::utils::kAssertionError;
+using ::tflite::support::utils::kInvalidPointer;
+using ::tflite::support::utils::StringListToVector;
+using ::tflite::support::utils::ThrowException;
+using ::tflite::task::vision::BoundingBox;
+using ::tflite::task::vision::ConvertToCategory;
+using ::tflite::task::vision::ConvertToFrameBufferOrientation;
+using ::tflite::task::vision::DetectionResult;
+using ::tflite::task::vision::FrameBuffer;
+using ::tflite::task::vision::ObjectDetector;
+using ::tflite::task::vision::ObjectDetectorOptions;
+
+// Creates an ObjectDetectorOptions proto based on the Java class.
+ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options) {
+ ObjectDetectorOptions proto_options;
+ jclass java_options_class = env->FindClass(
+ "org/tensorflow/lite/task/vision/detector/"
+ "ObjectDetector$ObjectDetectorOptions");
+
+ jmethodID display_names_locale_id = env->GetMethodID(
+ java_options_class, "getDisplayNamesLocale", "()Ljava/lang/String;");
+ jstring display_names_locale = static_cast<jstring>(
+ env->CallObjectMethod(java_options, display_names_locale_id));
+ const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr);
+ proto_options.set_display_names_locale(pchars);
+ env->ReleaseStringUTFChars(display_names_locale, pchars);
+
+ jmethodID max_results_id =
+ env->GetMethodID(java_options_class, "getMaxResults", "()I");
+ jint max_results = env->CallIntMethod(java_options, max_results_id);
+ proto_options.set_max_results(max_results);
+
+ jmethodID is_score_threshold_set_id =
+ env->GetMethodID(java_options_class, "getIsScoreThresholdSet", "()Z");
+ jboolean is_score_threshold_set =
+ env->CallBooleanMethod(java_options, is_score_threshold_set_id);
+ if (is_score_threshold_set) {
+ jmethodID score_threshold_id =
+ env->GetMethodID(java_options_class, "getScoreThreshold", "()F");
+ jfloat score_threshold =
+ env->CallFloatMethod(java_options, score_threshold_id);
+ proto_options.set_score_threshold(score_threshold);
+ }
+
+ jmethodID allow_list_id = env->GetMethodID(
+ java_options_class, "getLabelAllowList", "()Ljava/util/List;");
+ jobject allow_list = env->CallObjectMethod(java_options, allow_list_id);
+ std::vector<std::string> allow_list_vector =
+ StringListToVector(env, allow_list);
+ for (const auto& class_name : allow_list_vector) {
+ proto_options.add_class_name_whitelist(class_name);
+ }
+
+ jmethodID deny_list_id = env->GetMethodID(
+ java_options_class, "getLabelDenyList", "()Ljava/util/List;");
+ jobject deny_list = env->CallObjectMethod(java_options, deny_list_id);
+ auto deny_list_vector = StringListToVector(env, deny_list);
+ for (const auto& class_name : deny_list_vector) {
+ proto_options.add_class_name_blacklist(class_name);
+ }
+
+ jmethodID num_threads_id =
+ env->GetMethodID(java_options_class, "getNumThreads", "()I");
+ jint num_threads = env->CallIntMethod(java_options, num_threads_id);
+ proto_options.set_num_threads(num_threads);
+
+ return proto_options;
+}
+
+jobject ConvertToDetectionResults(JNIEnv* env, const DetectionResult& results) {
+ // jclass and init of Detection.
+ jclass detection_class =
+ env->FindClass("org/tensorflow/lite/task/vision/detector/Detection");
+ jmethodID detection_create = env->GetStaticMethodID(
+ detection_class, "create",
+ "(Landroid/graphics/RectF;Ljava/util/List;)Lorg/tensorflow/lite/"
+ "task/vision/detector/Detection;");
+
+ // jclass, init, and add of ArrayList.
+ jclass array_list_class = env->FindClass("java/util/ArrayList");
+ jmethodID array_list_init =
+ env->GetMethodID(array_list_class, "<init>", "(I)V");
+ jmethodID array_list_add_method =
+ env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z");
+
+ // jclass, init of RectF.
+ jclass rectf_class = env->FindClass("android/graphics/RectF");
+ jmethodID rectf_init = env->GetMethodID(rectf_class, "<init>", "(FFFF)V");
+
+ jobject detections_list =
+ env->NewObject(array_list_class, array_list_init,
+ static_cast<jint>(results.detections_size()));
+
+ for (const auto& detection : results.detections()) {
+ // Create the category list.
+ jobject category_list = env->NewObject(array_list_class, array_list_init,
+ detection.classes_size());
+ for (const auto& classification : detection.classes()) {
+ jobject jcategory = ConvertToCategory(env, classification);
+ env->CallBooleanMethod(category_list, array_list_add_method, jcategory);
+ }
+
+ // Create the bounding box object.
+ const BoundingBox& bounding_box = detection.bounding_box();
+ float left = static_cast<float>(bounding_box.origin_x());
+ float top = static_cast<float>(bounding_box.origin_y());
+ float right = static_cast<float>(left + bounding_box.width());
+ float bottom = static_cast<float>(top + bounding_box.height());
+ jobject jbounding_box =
+ env->NewObject(rectf_class, rectf_init, left, top, right, bottom);
+
+ // Create the java Detection object.
+ jobject jdetection = env->CallStaticObjectMethod(
+ detection_class, detection_create, jbounding_box, category_list);
+ env->CallBooleanMethod(detections_list, array_list_add_method, jdetection);
+ }
+ return detections_list;
+}
+
+jlong CreateObjectDetectorFromOptions(JNIEnv* env,
+ const ObjectDetectorOptions& options) {
+ StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
+ ObjectDetector::CreateFromOptions(options);
+ if (object_detector_or.ok()) {
+ return reinterpret_cast<jlong>(object_detector_or->release());
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when initializing ObjectDetector: %s",
+ object_detector_or.status().message().data());
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni(
+ JNIEnv* env, jobject thiz, jlong native_handle) {
+ delete reinterpret_cast<ObjectDetector*>(native_handle);
+}
+
+// Creates an ObjectDetector instance from the model file descriptor.
+// file_descriptor_length and file_descriptor_offset are optional. Non-positive
+// values will be ignored.
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdAndOptions(
+ JNIEnv* env, jclass thiz, jint file_descriptor,
+ jlong file_descriptor_length, jlong file_descriptor_offset,
+ jobject java_options) {
+ ObjectDetectorOptions proto_options =
+ ConvertToProtoOptions(env, java_options);
+ auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata()
+ ->mutable_file_descriptor_meta();
+ file_descriptor_meta->set_fd(file_descriptor);
+ if (file_descriptor_length > 0) {
+ file_descriptor_meta->set_length(file_descriptor_length);
+ }
+ if (file_descriptor_offset > 0) {
+ file_descriptor_meta->set_offset(file_descriptor_offset);
+ }
+ return CreateObjectDetectorFromOptions(env, proto_options);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuffer(
+ JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) {
+ ObjectDetectorOptions proto_options =
+ ConvertToProtoOptions(env, java_options);
+ proto_options.mutable_model_file_with_metadata()->set_file_content(
+ static_cast<char*>(env->GetDirectBufferAddress(model_buffer)),
+ static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer)));
+ return CreateObjectDetectorFromOptions(env, proto_options);
+}
+
+extern "C" JNIEXPORT jobject JNICALL
+Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_detectNative(
+ JNIEnv* env, jclass thiz, jlong native_handle, jobject image_byte_buffer,
+ jint width, jint height, jint jorientation) {
+ auto* detector = reinterpret_cast<ObjectDetector*>(native_handle);
+ absl::string_view image = GetMappedFileBuffer(env, image_byte_buffer);
+ std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
+ reinterpret_cast<const uint8*>(image.data()),
+ FrameBuffer::Dimension{width, height},
+ ConvertToFrameBufferOrientation(env, jorientation));
+ auto results_or = detector->Detect(*frame_buffer);
+ if (results_or.ok()) {
+ return ConvertToDetectionResults(env, results_or.value());
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when detecting the image: %s",
+ results_or.status().message().data());
+ return nullptr;
+ }
+}
+} // namespace
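`ConvertToDetectionResults` converts each proto `BoundingBox` (origin plus size) into an Android `RectF` (left, top, right = left + width, bottom = top + height) before calling the static factory whose `GetStaticMethodID` signature is `(Landroid/graphics/RectF;Ljava/util/List;)L.../Detection;`. A sketch of that factory under those assumptions — only the name `create` and the parameter/return types are fixed by the JNI lookup:

```java
package org.tensorflow.lite.task.vision.detector;

import android.graphics.RectF;
import java.util.List;
import org.tensorflow.lite.support.label.Category;

public final class Detection {
  private final RectF boundingBox;
  private final List<Category> categories;

  private Detection(RectF boundingBox, List<Category> categories) {
    this.boundingBox = boundingBox;
    this.categories = categories;
  }

  // Shape fixed by "(Landroid/graphics/RectF;Ljava/util/List;)L.../Detection;";
  // the JNI layer passes new RectF(originX, originY,
  //                                originX + width, originY + height).
  public static Detection create(RectF boundingBox, List<Category> categories) {
    return new Detection(boundingBox, categories);
  }
}
```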
diff --git a/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc b/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
new file mode 100644
index 00000000..af5dad96
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc
@@ -0,0 +1,85 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+using ::tflite::support::utils::kAssertionError;
+using ::tflite::support::utils::ThrowException;
+
+constexpr char kCategoryClassName[] =
+ "org/tensorflow/lite/support/label/Category";
+constexpr char kStringClassName[] = "Ljava/lang/String;";
+constexpr char kEmptyString[] = "";
+
+jobject ConvertToCategory(JNIEnv* env, const Class& classification) {
+ // jclass and init of Category.
+ jclass category_class = env->FindClass(kCategoryClassName);
+ jmethodID category_create = env->GetStaticMethodID(
+ category_class, "create",
+ absl::StrCat("(", kStringClassName, kStringClassName, "F)L",
+ kCategoryClassName, ";")
+ .c_str());
+
+ std::string label_string = classification.has_class_name()
+ ? classification.class_name()
+ : std::to_string(classification.index());
+ jstring label = env->NewStringUTF(label_string.c_str());
+ std::string display_name_string = classification.has_display_name()
+ ? classification.display_name()
+ : kEmptyString;
+ jstring display_name = env->NewStringUTF(display_name_string.c_str());
+ jobject jcategory =
+ env->CallStaticObjectMethod(category_class, category_create, label,
+ display_name, classification.score());
+ return jcategory;
+}
+
+FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env,
+ jint jorientation) {
+ switch (jorientation) {
+ case 0:
+ return FrameBuffer::Orientation::kTopLeft;
+ case 1:
+ return FrameBuffer::Orientation::kTopRight;
+ case 2:
+ return FrameBuffer::Orientation::kBottomRight;
+ case 3:
+ return FrameBuffer::Orientation::kBottomLeft;
+ case 4:
+ return FrameBuffer::Orientation::kLeftTop;
+ case 5:
+ return FrameBuffer::Orientation::kRightTop;
+ case 6:
+ return FrameBuffer::Orientation::kRightBottom;
+ case 7:
+ return FrameBuffer::Orientation::kLeftBottom;
+ }
+ // Should never happen.
+ ThrowException(env, kAssertionError,
+ "The FrameBuffer Orientation type is unsupported: %d",
+ jorientation);
+ return FrameBuffer::Orientation::kTopLeft;
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
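`ConvertToFrameBufferOrientation` trusts the caller to pass an integer in `[0, 7]`, matching the eight EXIF-style orientations of `FrameBuffer::Orientation`. A hypothetical Java-side enum whose ordinals line up with the switch above — the real constant names on the Java side are not shown in this diff:

```java
// Ordinal values must match the jorientation switch in jni_utils.cc.
public enum Orientation {
  TOP_LEFT,      // 0
  TOP_RIGHT,     // 1
  BOTTOM_RIGHT,  // 2
  BOTTOM_LEFT,   // 3
  LEFT_TOP,      // 4
  RIGHT_TOP,     // 5
  RIGHT_BOTTOM,  // 6
  LEFT_BOTTOM;   // 7

  // Passed across JNI as an int, consumed by ConvertToFrameBufferOrientation.
  public int toJniCode() {
    return ordinal();
  }
}
```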
diff --git a/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h b/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
new file mode 100644
index 00000000..7cb63f31
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TASK_VISION_JNI_UTILS_H_
+#define TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TASK_VISION_JNI_UTILS_H_
+
+#include <jni.h>
+
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Creates a Java Category object based on Class.
+jobject ConvertToCategory(JNIEnv* env, const Class& classification);
+
+FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env,
+ jint jorientation);
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TASK_VISION_JNI_UTILS_H_
diff --git a/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD b/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD
new file mode 100644
index 00000000..23d601df
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD
@@ -0,0 +1,34 @@
+load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(["image_segmenter_jni.cc"])
+
+cc_library(
+ name = "image_segmenter_native",
+ srcs = [
+ ":libtask_vision_jni.so",
+ ],
+)
+
+tflite_jni_binary(
+ name = "libtask_vision_jni.so",
+ srcs = [
+ "image_segmenter_jni.cc",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/cc/task/vision:image_segmenter",
+ "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+ "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc",
+ "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils",
+ "//tensorflow_lite_support/cc/utils:jni_utils",
+ "//tensorflow_lite_support/java/jni",
+ "//tensorflow_lite_support/java/src/native/task/vision:jni_utils",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc b/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
new file mode 100644
index 00000000..1f6a2dc3
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc
@@ -0,0 +1,238 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <jni.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+#include "tensorflow_lite_support/cc/utils/jni_utils.h"
+#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h"
+
+namespace {
+
+using ::tflite::support::StatusOr;
+using ::tflite::support::utils::CreateByteArray;
+using ::tflite::support::utils::GetMappedFileBuffer;
+using ::tflite::support::utils::kAssertionError;
+using ::tflite::support::utils::kIllegalArgumentException;
+using ::tflite::support::utils::kInvalidPointer;
+using ::tflite::support::utils::ThrowException;
+using ::tflite::task::vision::ConvertToFrameBufferOrientation;
+using ::tflite::task::vision::FrameBuffer;
+using ::tflite::task::vision::ImageSegmenter;
+using ::tflite::task::vision::ImageSegmenterOptions;
+using ::tflite::task::vision::Segmentation;
+using ::tflite::task::vision::SegmentationResult;
+
+constexpr char kArrayListClassNameNoSig[] = "java/util/ArrayList";
+constexpr char kObjectClassName[] = "Ljava/lang/Object;";
+constexpr char kColorClassName[] = "Landroid/graphics/Color;";
+constexpr char kColorClassNameNoSig[] = "android/graphics/Color";
+constexpr char kColoredLabelClassName[] =
+ "Lorg/tensorflow/lite/task/vision/segmenter/ColoredLabel;";
+constexpr char kColoredLabelClassNameNoSig[] =
+ "org/tensorflow/lite/task/vision/segmenter/ColoredLabel";
+constexpr char kStringClassName[] = "Ljava/lang/String;";
+constexpr int kOutputTypeCategoryMask = 0;
+constexpr int kOutputTypeConfidenceMask = 1;
+
+// Creates an ImageSegmenterOptions proto based on the Java class.
+ImageSegmenterOptions ConvertToProtoOptions(JNIEnv* env,
+ jstring display_names_locale,
+ jint output_type,
+ jint num_threads) {
+ ImageSegmenterOptions proto_options;
+
+ const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr);
+ proto_options.set_display_names_locale(pchars);
+ env->ReleaseStringUTFChars(display_names_locale, pchars);
+
+ switch (output_type) {
+ case kOutputTypeCategoryMask:
+ proto_options.set_output_type(ImageSegmenterOptions::CATEGORY_MASK);
+ break;
+ case kOutputTypeConfidenceMask:
+ proto_options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK);
+ break;
+ default:
+ // Should never happen.
+ ThrowException(env, kIllegalArgumentException,
+ "Unsupported output type: %d", output_type);
+ }
+
+ proto_options.set_num_threads(num_threads);
+
+ return proto_options;
+}
+
+void ConvertToSegmentationResults(JNIEnv* env,
+ const SegmentationResult& results,
+ jobject jmask_buffers, jintArray jmask_shape,
+ jobject jcolored_labels) {
+ if (results.segmentation_size() != 1) {
+ // Should never happen.
+ ThrowException(
+ env, kAssertionError,
+ "ImageSegmenter only supports one segmentation result, getting %d",
+ results.segmentation_size());
+ }
+
+ const Segmentation& segmentation = results.segmentation(0);
+
+ // Get the shape from the C++ Segmentation results.
+ int shape_array[2] = {segmentation.height(), segmentation.width()};
+ env->SetIntArrayRegion(jmask_shape, 0, 2, shape_array);
+
+ // jclass, init, and add of ArrayList.
+ jclass array_list_class = env->FindClass(kArrayListClassNameNoSig);
+ jmethodID array_list_add_method =
+ env->GetMethodID(array_list_class, "add",
+ absl::StrCat("(", kObjectClassName, ")Z").c_str());
+
+ // Convert the masks into ByteBuffer list.
+ int num_pixels = segmentation.height() * segmentation.width();
+ if (segmentation.has_category_mask()) {
+ jbyteArray byte_array = CreateByteArray(
+ env,
+ reinterpret_cast<const jbyte*>(segmentation.category_mask().data()),
+ num_pixels * sizeof(uint8));
+ env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array);
+ env->DeleteLocalRef(byte_array);
+ } else {
+ for (const auto& confidence_mask :
+ segmentation.confidence_masks().confidence_mask()) {
+ jbyteArray byte_array = CreateByteArray(
+ env, reinterpret_cast<const jbyte*>(confidence_mask.value().data()),
+ num_pixels * sizeof(float));
+ env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array);
+ env->DeleteLocalRef(byte_array);
+ }
+ }
+
+ // Convert colored labels from the C++ object to the Java object.
+ jclass color_class = env->FindClass(kColorClassNameNoSig);
+ jmethodID color_rgb_method =
+ env->GetStaticMethodID(color_class, "rgb", "(III)I");
+ jclass colored_label_class = env->FindClass(kColoredLabelClassNameNoSig);
+ jmethodID colored_label_create_method = env->GetStaticMethodID(
+ colored_label_class, "create",
+ absl::StrCat("(", kStringClassName, kStringClassName, "I)",
+ kColoredLabelClassName)
+ .c_str());
+
+ for (const auto& colored_label : segmentation.colored_labels()) {
+ jstring label = env->NewStringUTF(colored_label.class_name().c_str());
+ jstring display_name =
+ env->NewStringUTF(colored_label.display_name().c_str());
+ jint rgb = env->CallStaticIntMethod(color_class, color_rgb_method,
+ colored_label.r(), colored_label.g(),
+ colored_label.b());
+ jobject jcolored_label = env->CallStaticObjectMethod(
+ colored_label_class, colored_label_create_method, label, display_name,
+ rgb);
+ env->CallBooleanMethod(jcolored_labels, array_list_add_method,
+ jcolored_label);
+
+ env->DeleteLocalRef(label);
+ env->DeleteLocalRef(display_name);
+ env->DeleteLocalRef(jcolored_label);
+ }
+}
+
+jlong CreateImageSegmenterFromOptions(JNIEnv* env,
+ const ImageSegmenterOptions& options) {
+ StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
+ ImageSegmenter::CreateFromOptions(options);
+ if (image_segmenter_or.ok()) {
+ return reinterpret_cast<jlong>(image_segmenter_or->release());
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when initializing ImageSegmenter: %s",
+ image_segmenter_or.status().message().data());
+ return kInvalidPointer;
+ }
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni(
+ JNIEnv* env, jobject thiz, jlong native_handle) {
+ delete reinterpret_cast<ImageSegmenter*>(native_handle);
+}
+
+// Creates an ImageSegmenter instance from the model file descriptor.
+// file_descriptor_length and file_descriptor_offset are optional. Non-positive
+// values will be ignored.
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions(
+ JNIEnv* env, jclass thiz, jint file_descriptor,
+ jlong file_descriptor_length, jlong file_descriptor_offset,
+ jstring display_names_locale, jint output_type, jint num_threads) {
+ ImageSegmenterOptions proto_options = ConvertToProtoOptions(
+ env, display_names_locale, output_type, num_threads);
+ auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata()
+ ->mutable_file_descriptor_meta();
+ file_descriptor_meta->set_fd(file_descriptor);
+ if (file_descriptor_length > 0) {
+ file_descriptor_meta->set_length(file_descriptor_length);
+ }
+ if (file_descriptor_offset > 0) {
+ file_descriptor_meta->set_offset(file_descriptor_offset);
+ }
+ return CreateImageSegmenterFromOptions(env, proto_options);
+}
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer(
+ JNIEnv* env, jclass thiz, jobject model_buffer,
+ jstring display_names_locale, jint output_type, jint num_threads) {
+ ImageSegmenterOptions proto_options = ConvertToProtoOptions(
+ env, display_names_locale, output_type, num_threads);
+ proto_options.mutable_model_file_with_metadata()->set_file_content(
+ static_cast<char*>(env->GetDirectBufferAddress(model_buffer)),
+ static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer)));
+ return CreateImageSegmenterFromOptions(env, proto_options);
+}
+
+extern "C" JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative(
+ JNIEnv* env, jclass thiz, jlong native_handle, jobject jimage_byte_buffer,
+ jint width, jint height, jobject jmask_buffers, jintArray jmask_shape,
+ jobject jcolored_labels, jint jorientation) {
+ auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle);
+ absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer);
+ std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer(
+ reinterpret_cast<const uint8*>(image.data()),
+ FrameBuffer::Dimension{width, height},
+ ConvertToFrameBufferOrientation(env, jorientation));
+ auto results_or = segmenter->Segment(*frame_buffer);
+ if (results_or.ok()) {
+ ConvertToSegmentationResults(env, results_or.value(), jmask_buffers,
+ jmask_shape, jcolored_labels);
+ } else {
+ ThrowException(env, kAssertionError,
+ "Error occurred when segmenting the image: %s",
+ results_or.status().message().data());
+ }
+}
+
+} // namespace
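Two integer contracts cross the JNI boundary in this file: the output type (`0` for a category mask, `1` for confidence masks, per `kOutputTypeCategoryMask` / `kOutputTypeConfidenceMask`) and the `ColoredLabel` factory signature `(Ljava/lang/String;Ljava/lang/String;I)L...ColoredLabel;`. A sketch under those assumptions — the placement of the constants, field names, and everything beyond the factory shape is invented for illustration:

```java
package org.tensorflow.lite.task.vision.segmenter;

public final class ColoredLabel {
  // Must stay in sync with kOutputTypeCategoryMask / kOutputTypeConfidenceMask
  // in image_segmenter_jni.cc; where these constants actually live on the
  // Java side is an assumption.
  public static final int OUTPUT_TYPE_CATEGORY_MASK = 0;
  public static final int OUTPUT_TYPE_CONFIDENCE_MASK = 1;

  private final String label;
  private final String displayName;
  private final int argb;  // packed via android.graphics.Color.rgb(r, g, b)

  private ColoredLabel(String label, String displayName, int argb) {
    this.label = label;
    this.displayName = displayName;
    this.argb = argb;
  }

  // Shape fixed by the GetStaticMethodID lookup in the segmenter JNI.
  public static ColoredLabel create(String label, String displayName, int argb) {
    return new ColoredLabel(label, displayName, argb);
  }
}
```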
diff --git a/tensorflow_lite_support/metadata/BUILD b/tensorflow_lite_support/metadata/BUILD
new file mode 100644
index 00000000..db69bd25
--- /dev/null
+++ b/tensorflow_lite_support/metadata/BUILD
@@ -0,0 +1,51 @@
+load("@flatbuffers//:build_defs.bzl", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library", "flatbuffer_py_library")
+
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+exports_files(["metadata_schema.fbs"])
+
+flatbuffer_py_library(
+ name = "schema_py",
+ srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"],
+)
+
+# Generic schema for inference on device.
+flatbuffer_android_library(
+ name = "schema_fbs_android",
+ srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"],
+ custom_package = "org.tensorflow.lite.schema",
+)
+
+flatbuffer_java_library(
+ name = "schema_fbs_java",
+ srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"],
+ custom_package = "org.tensorflow.lite.schema",
+)
+
+# Generic schema for model metadata.
+flatbuffer_cc_library(
+ name = "metadata_schema_cc",
+ srcs = ["metadata_schema.fbs"],
+)
+
+flatbuffer_py_library(
+ name = "metadata_schema_py",
+ srcs = ["metadata_schema.fbs"],
+)
+
+flatbuffer_java_library(
+ name = "metadata_schema_java",
+ srcs = ["metadata_schema.fbs"],
+ custom_package = "org.tensorflow.lite.support.metadata.schema",
+)
+
+flatbuffer_android_library(
+ name = "metadata_schema_fbs_android",
+ srcs = ["metadata_schema.fbs"],
+ custom_package = "org.tensorflow.lite.support.metadata.schema",
+)
diff --git a/tensorflow_lite_support/metadata/README.md b/tensorflow_lite_support/metadata/README.md
new file mode 100644
index 00000000..ff7d25f2
--- /dev/null
+++ b/tensorflow_lite_support/metadata/README.md
@@ -0,0 +1,15 @@
+# TensorFlow Lite Metadata and Android wrapper code generator
+
+Note: Both TensorFlow Lite Metadata and the Android wrapper code generator are
+in the experimental (beta) phase.
+
+TensorFlow Lite metadata provides a structured framework for storing metadata
+to convey information both to developers who will use the model and to code
+generators which can create wrappers around the model. For information on
+how to populate model metadata, please refer to the [TensorFlow Lite Metadata
+documentation](https://www.tensorflow.org/lite/convert/metadata).
+
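+A minimal inspection sketch (assuming the `tflite_support` Python package and
+a hypothetical `model.tflite` containing packed metadata):
+
+```python
+from tflite_support import metadata
+
+# Load the model file and print its packed metadata as JSON.
+displayer = metadata.MetadataDisplayer.with_model_file("model.tflite")
+print(displayer.get_metadata_json())
+```
+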
+The first code generator which takes advantage of this metadata format is the
+TensorFlow Lite Android Code Generator. For more information on how to use this
+generator, please refer to the [TensorFlow Lite Android wrapper code generator
+documentation](https://www.tensorflow.org/lite/guide/codegen).
diff --git a/tensorflow_lite_support/metadata/build_defs.bzl b/tensorflow_lite_support/metadata/build_defs.bzl
new file mode 100644
index 00000000..8bdab125
--- /dev/null
+++ b/tensorflow_lite_support/metadata/build_defs.bzl
@@ -0,0 +1,43 @@
+"""Build rules to generate metadata schema versions."""
+
+METADATA_SCHEMA_FILE = "//tensorflow_lite_support/metadata:metadata_schema.fbs"
+
+def stamp_metadata_parser_version(
+ name,
+ srcs,
+ outs):
+ """Stamps the latest metadata parser version into the srcs files.
+
+ Replaces all the occurrences of "{LATEST_METADATA_PARSER_VERSION}" in the
+ srcs files with the metadata schema version extracted from
+ METADATA_SCHEMA_FILE and then outputs the generated file into outs,
+ respectively. The number of srcs files needs to match the number of outs
+ files.
+
+ Args:
+ name: Rule name. (required)
+ srcs: List of source files. (required)
+ outs: List of output files. (required)
+ """
+ if len(srcs) != len(outs):
+ fail(("The number of srcs files (%d) does not match that of the outs" +
+ " files (%d).") %
+ (len(srcs), len(outs)))
+
+ for i in range(0, len(srcs)):
+ native.genrule(
+ name = "%s_file%d" % (name, i),
+ srcs = [srcs[i]],
+ outs = [outs[i]],
+ tools = [METADATA_SCHEMA_FILE],
+ # Gets the metadata schema version from the file, and stamps it
+ # into the srcs file.
+ cmd = "version=$$(sed -n -e '/Schema Semantic version/ s/.*\\: *//p' $(location %s));" %
+ METADATA_SCHEMA_FILE +
+ 'sed "s/{LATEST_METADATA_PARSER_VERSION}/$$version/" $< > $@',
+ )
+
+ native.filegroup(
+ name = name,
+ srcs = outs,
+ )
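+
+# Usage sketch (mirrors the rule instantiation in
+# tensorflow_lite_support/metadata/cc/BUILD):
+#
+#   load("//tensorflow_lite_support/metadata:build_defs.bzl",
+#        "stamp_metadata_parser_version")
+#
+#   stamp_metadata_parser_version(
+#       name = "metadata_parser_h",
+#       srcs = ["metadata_parser.h.template"],
+#       outs = ["metadata_parser.h"],
+#   )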
diff --git a/tensorflow_lite_support/metadata/cc/BUILD b/tensorflow_lite_support/metadata/cc/BUILD
new file mode 100644
index 00000000..ed5bedc0
--- /dev/null
+++ b/tensorflow_lite_support/metadata/cc/BUILD
@@ -0,0 +1,53 @@
+load("//tensorflow_lite_support/metadata:build_defs.bzl", "stamp_metadata_parser_version")
+
+package(
+ default_visibility = ["//tensorflow_lite_support:users"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+stamp_metadata_parser_version(
+ name = "metadata_parser_h",
+ srcs = ["metadata_parser.h.template"],
+ outs = ["metadata_parser.h"],
+)
+
+cc_library(
+ name = "metadata_extractor",
+ srcs = ["metadata_extractor.cc"],
+ hdrs = ["metadata_extractor.h"],
+ deps = [
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@flatbuffers",
+ "@org_libzip//:zip",
+ ] + select({
+ "//tensorflow_lite_support/cc:tflite_use_c_api": ["@org_tensorflow//tensorflow/lite/c:c_api"],
+ "//conditions:default": ["@org_tensorflow//tensorflow/lite:framework"],
+ }) + [
+ "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+ "//tensorflow_lite_support/cc:common",
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/port:statusor",
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ ],
+)
+
+cc_library(
+ name = "metadata_version",
+ srcs = ["metadata_version.cc"],
+ hdrs = [
+ "metadata_version.h",
+ ":metadata_parser_h",
+ ],
+ deps = [
+ "//tensorflow_lite_support/metadata:metadata_schema_cc",
+ "@com_google_absl//absl/strings",
+ "@flatbuffers",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@org_tensorflow//tensorflow/lite/kernels/internal:compatibility",
+ "@org_tensorflow//tensorflow/lite/tools:logging",
+ ],
+)
diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
new file mode 100644
index 00000000..cf4edaa7
--- /dev/null
+++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
@@ -0,0 +1,366 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+
+#include <functional>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "lib/zip.h" // from @org_libzip
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+#if TFLITE_USE_C_API
+#include "tensorflow/lite/c/c_api.h"
+#else
+#include "tensorflow/lite/model_builder.h"
+#endif
+
+namespace tflite {
+namespace metadata {
+
+namespace {
+constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
+
+using ::absl::StatusCode;
+using ::flatbuffers::Offset;
+using ::flatbuffers::Vector;
+using ::tflite::TensorMetadata;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::TfLiteSupportStatus;
+
+// Helper class that takes a callback function, and invokes it in its
+// destructor.
+class SimpleCleanUp {
+ public:
+ explicit SimpleCleanUp(std::function<void()> callback)
+ : callback_(std::move(callback)) {}
+
+ ~SimpleCleanUp() {
+ if (callback_ != nullptr) callback_();
+ }
+
+ // Use `std::move(simple_cleanup).Cancel()` to prevent the callback from ever
+ // executing at all. Once a SimpleCleanUp object has been `std::move(...)`-ed,
+ // it may not be read from again.
+ void Cancel() && { callback_ = nullptr; }
+
+ private:
+ std::function<void()> callback_;
+};
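+//
+// A usage sketch (illustrative; mirrors the zip cleanups below):
+//   zip* archive = ...;
+//   auto cleanup = SimpleCleanUp([archive] { zip_close(archive); });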
+
+// Utility to get the item at the given index from src_vector.
+template <typename T>
+const T* GetItemFromVector(
+ const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
+ if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
+ return nullptr;
+ }
+ return src_vector->Get(index);
+}
+} // namespace
+
+/* static */
+tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
+ModelMetadataExtractor::CreateFromModelBuffer(const char* buffer_data,
+ size_t buffer_size) {
+ // Use absl::WrapUnique() to call private constructor:
+ // https://abseil.io/tips/126.
+ std::unique_ptr<ModelMetadataExtractor> extractor =
+ absl::WrapUnique(new ModelMetadataExtractor());
+ RETURN_IF_ERROR(extractor->InitFromModelBuffer(buffer_data, buffer_size));
+ return extractor;
+}
+
+/* static */
+tflite::support::StatusOr<const tflite::ProcessUnit*>
+ModelMetadataExtractor::FindFirstProcessUnit(
+ const tflite::TensorMetadata& tensor_metadata,
+ tflite::ProcessUnitOptions type) {
+ const tflite::ProcessUnit* result = nullptr;
+ if (tensor_metadata.process_units() == nullptr) {
+ return result;
+ }
+ for (const auto process_unit : *tensor_metadata.process_units()) {
+ if (process_unit->options_type() == type) {
+ if (result != nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrCat("Found multiple ProcessUnits with type=",
+ tflite::EnumNameProcessUnitOptions(type),
+ ", expected at most one."),
+ TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
+ }
+ result = process_unit;
+ }
+ }
+ return result;
+}
+
+/* static */
+std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
+ const tflite::TensorMetadata& tensor_metadata,
+ tflite::AssociatedFileType type, absl::string_view locale) {
+ if (tensor_metadata.associated_files() == nullptr) {
+ return std::string();
+ }
+ for (const auto associated_file : *tensor_metadata.associated_files()) {
+ if (associated_file->type() != type || associated_file->name() == nullptr) {
+ continue;
+ }
+ if (locale.empty() || (associated_file->locale() != nullptr &&
+ locale == associated_file->locale()->str())) {
+ return associated_file->name()->str();
+ }
+ }
+ return std::string();
+}
+
+absl::Status ModelMetadataExtractor::InitFromModelBuffer(
+ const char* buffer_data, size_t buffer_size) {
+ // Rely on the simplest, base flatbuffers verifier. Here is not the place to
+ // e.g. use an OpResolver: we just want to make sure the buffer is valid to
+ // access the metadata.
+ flatbuffers::Verifier verifier = flatbuffers::Verifier(
+ reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
+ if (!tflite::VerifyModelBuffer(verifier)) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "The model is not a valid FlatBuffer buffer.",
+ TfLiteSupportStatus::kInvalidFlatBufferError);
+ }
+ model_ = tflite::GetModel(buffer_data);
+ if (model_->metadata() == nullptr) {
+ // Not all models have metadata, which is OK. `GetModelMetadata()` then
+ // returns nullptr.
+ return absl::OkStatus();
+ }
+ // Look for the "TFLITE_METADATA" field, if any.
+ for (int i = 0; i < model_->metadata()->size(); ++i) {
+ const auto metadata = model_->metadata()->Get(i);
+ if (metadata->name()->str() != kMetadataBufferName) {
+ continue;
+ }
+ const auto buffer_index = metadata->buffer();
+ const auto metadata_buffer =
+ model_->buffers()->Get(buffer_index)->data()->data();
+ if (!tflite::ModelMetadataBufferHasIdentifier(metadata_buffer)) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Invalid metadata schema version: expected %s, got %s",
+ absl::string_view(tflite::ModelMetadataIdentifier())
+ .substr(
+ 0, flatbuffers::FlatBufferBuilder::kFileIdentifierLength),
+ // Returned identifier is not null terminated; has to be
+ // truncated.
+ absl::string_view(
+ flatbuffers::GetBufferIdentifier(metadata_buffer))
+ .substr(
+ 0,
+ flatbuffers::FlatBufferBuilder::kFileIdentifierLength)),
+ TfLiteSupportStatus::kMetadataInvalidSchemaVersionError);
+ }
+ model_metadata_ = tflite::GetModelMetadata(metadata_buffer);
+ if (model_metadata_ == nullptr) {
+ return CreateStatusWithPayload(StatusCode::kInternal,
+ "Expected Model Metadata not to be null.");
+ }
+    return ExtractAssociatedFiles(buffer_data, buffer_size);
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
+ const char* buffer_data, size_t buffer_size) {
+ // Setup libzip error reporting.
+ zip_error_t error;
+ zip_error_init(&error);
+ auto zip_error_cleanup = SimpleCleanUp([&error] { zip_error_fini(&error); });
+
+ // Initialize zip source.
+ zip_source_t* src =
+ zip_source_buffer_create(buffer_data, buffer_size, /*freep=*/0, &error);
+ if (src == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown,
+ absl::StrFormat("Can't create zip source from model buffer: %s",
+ zip_error_strerror(&error)),
+ TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+ }
+ auto zip_source_cleanup = SimpleCleanUp([src] { zip_source_free(src); });
+
+ // Try opening zip source.
+ zip* zip_archive = zip_open_from_source(src, /*flags=*/0, &error);
+ if (zip_archive == nullptr) {
+ // It's OK if it fails: this means there are no associated files with this
+ // model.
+ return absl::OkStatus();
+ }
+ auto zip_archive_cleanup =
+ SimpleCleanUp([zip_archive] { zip_close(zip_archive); });
+ // As per the documentation [1] for zip_source_free, it should not be called
+ // after a successful call to zip_open_from_source.
+ //
+ // [1]: https://libzip.org/documentation/zip_source_free.html
+ std::move(zip_source_cleanup).Cancel();
+
+ const int num_files = zip_get_num_entries(zip_archive, /*flags=*/0);
+ for (int index = 0; index < num_files; ++index) {
+ // Get file stats.
+ struct zip_stat zip_file_stat;
+ zip_stat_init(&zip_file_stat);
+ zip_stat_index(zip_archive, index, /*flags=*/0, &zip_file_stat);
+ absl::string_view filename = zip_file_stat.name;
+ const auto unzip_filesize = zip_file_stat.size;
+
+ // Open file.
+ zip_file* zip_file = zip_fopen_index(zip_archive, index, /*flags=*/0);
+ if (zip_file == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown,
+ absl::StrFormat("Unable to open associated file with name: %s",
+ zip_file_stat.name),
+ TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+ }
+ auto zip_file_cleanup = SimpleCleanUp([zip_file] { zip_fclose(zip_file); });
+
+ // Unzip file.
+ char* unzip_buffer = new char[unzip_filesize];
+ auto unzip_buffer_cleanup =
+ SimpleCleanUp([unzip_buffer] { delete[] unzip_buffer; });
+ if (zip_fread(zip_file, unzip_buffer, unzip_filesize) != unzip_filesize) {
+ return CreateStatusWithPayload(
+ StatusCode::kUnknown,
+ absl::StrFormat("Unzipping failed for file: %s.", filename),
+ TfLiteSupportStatus::kMetadataAssociatedFileZipError);
+ }
+
+ // Copy file contents in map.
+ associated_files_[filename] = std::string(unzip_buffer, unzip_filesize);
+ }
+ return absl::OkStatus();
+}
+
+tflite::support::StatusOr<absl::string_view>
+ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const {
+ auto it = associated_files_.find(filename);
+ if (it == associated_files_.end()) {
+ return CreateStatusWithPayload(
+ StatusCode::kNotFound,
+ absl::StrFormat("No associated file with name: %s", filename),
+ TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
+ }
+ return it->second;
+}
+
+const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
+ModelMetadataExtractor::GetInputTensorMetadata() const {
+ if (model_metadata_ == nullptr ||
+ model_metadata_->subgraph_metadata() == nullptr) {
+ return nullptr;
+ }
+ return model_metadata_->subgraph_metadata()
+ ->Get(kDefaultSubgraphIndex)
+ ->input_tensor_metadata();
+}
+
+const tflite::TensorMetadata* ModelMetadataExtractor::GetInputTensorMetadata(
+ int index) const {
+ return GetItemFromVector<tflite::TensorMetadata>(GetInputTensorMetadata(),
+ index);
+}
+
+int ModelMetadataExtractor::GetInputTensorCount() const {
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
+ input_tensor_metadata = GetInputTensorMetadata();
+ return input_tensor_metadata == nullptr ? 0 : input_tensor_metadata->size();
+}
+
+const Vector<Offset<TensorMetadata>>*
+ModelMetadataExtractor::GetOutputTensorMetadata() const {
+ if (model_metadata_ == nullptr ||
+ model_metadata_->subgraph_metadata() == nullptr) {
+ return nullptr;
+ }
+ return model_metadata_->subgraph_metadata()
+ ->Get(kDefaultSubgraphIndex)
+ ->output_tensor_metadata();
+}
+
+const tflite::TensorMetadata* ModelMetadataExtractor::GetOutputTensorMetadata(
+ int index) const {
+ return GetItemFromVector<tflite::TensorMetadata>(GetOutputTensorMetadata(),
+ index);
+}
+
+int ModelMetadataExtractor::GetOutputTensorCount() const {
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
+ output_tensor_metadata = GetOutputTensorMetadata();
+ return output_tensor_metadata == nullptr ? 0 : output_tensor_metadata->size();
+}
+
+const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
+ModelMetadataExtractor::GetInputProcessUnits() const {
+ if (model_metadata_ == nullptr ||
+ model_metadata_->subgraph_metadata() == nullptr) {
+ return nullptr;
+ }
+ return model_metadata_->subgraph_metadata()
+ ->Get(kDefaultSubgraphIndex)
+ ->input_process_units();
+}
+
+const tflite::ProcessUnit* ModelMetadataExtractor::GetInputProcessUnit(
+ int index) const {
+ return GetItemFromVector<tflite::ProcessUnit>(GetInputProcessUnits(), index);
+}
+
+int ModelMetadataExtractor::GetInputProcessUnitsCount() const {
+ const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* input_process_units =
+ GetInputProcessUnits();
+ return input_process_units == nullptr ? 0 : input_process_units->size();
+}
+
+const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
+ModelMetadataExtractor::GetOutputProcessUnits() const {
+ if (model_metadata_ == nullptr ||
+ model_metadata_->subgraph_metadata() == nullptr) {
+ return nullptr;
+ }
+ return model_metadata_->subgraph_metadata()
+ ->Get(kDefaultSubgraphIndex)
+ ->output_process_units();
+}
+
+const tflite::ProcessUnit* ModelMetadataExtractor::GetOutputProcessUnit(
+ int index) const {
+ return GetItemFromVector<tflite::ProcessUnit>(GetOutputProcessUnits(), index);
+}
+
+int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
+ const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* output_process_units =
+ GetOutputProcessUnits();
+ return output_process_units == nullptr ? 0 : output_process_units->size();
+}
+
+} // namespace metadata
+} // namespace tflite
diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/tensorflow_lite_support/metadata/cc/metadata_extractor.h
new file mode 100644
index 00000000..8eafe932
--- /dev/null
+++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.h
@@ -0,0 +1,157 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
+#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace metadata {
+
+// Extracts and provides easy access to the TFLite ModelMetadata [1] and
+// corresponding associated files packed into a TFLite FlatBuffer, if any.
+//
+// [1]: https://www.tensorflow.org/lite/convert/metadata
+class ModelMetadataExtractor {
+ public:
+ // Creates a ModelMetadataExtractor from the provided TFLite Model FlatBuffer
+ // and returns a pointer to the new object. Ownership is transferred to the
+ // caller. Returns an error if the creation failed, which may happen e.g. if
+ // the provided buffer is not a valid TFLite FlatBuffer.
+ //
+ // Warning: Does not take ownership of the provided buffer, which must outlive
+ // this object.
+ //
+ // It is recommended to obtain and manage the buffer through an
+ // ExternalFileHandler[1], which is optimized through mmap(2) to avoid having
+ // to load the entire buffer in memory when provided by path or file
+ // descriptor.
+ //
+ // [1]:
+  // tensorflow_lite_support/cc/task/core/external_file_handler.h
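+  //
+  // A minimal usage sketch (illustrative; ASSIGN_OR_RETURN comes from
+  // tensorflow_lite_support/cc/port/status_macros.h):
+  //   ASSIGN_OR_RETURN(
+  //       std::unique_ptr<ModelMetadataExtractor> extractor,
+  //       ModelMetadataExtractor::CreateFromModelBuffer(data, size));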
+ static tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
+ CreateFromModelBuffer(const char* buffer_data, size_t buffer_size);
+
+ // Returns the pointer to the *first* ProcessUnit with the provided type, or
+ // nullptr if none can be found. An error is returned if multiple
+ // ProcessUnit-s with the provided type are found.
+ static tflite::support::StatusOr<const tflite::ProcessUnit*>
+ FindFirstProcessUnit(const tflite::TensorMetadata& tensor_metadata,
+ tflite::ProcessUnitOptions type);
+
+ // Returns the name of the *first* associated file with the provided type and
+ // (optional) locale in the provided TensorMetadata, or an empty string if no
+ // such associated file can be found (which is not necessarily an error: some
+ // models have no associated files at all) or its `name` field is unspecified.
+ // Note: see `GetAssociatedFile` to read the actual file contents.
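+  //
+  // For example (illustrative), to look up an axis-labels file:
+  //   std::string name = ModelMetadataExtractor::FindFirstAssociatedFileName(
+  //       *tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS);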
+ static std::string FindFirstAssociatedFileName(
+ const tflite::TensorMetadata& tensor_metadata,
+ tflite::AssociatedFileType type,
+ absl::string_view locale = absl::string_view());
+
+ // Returns a pointer to the extracted TFLite Model Metadata, or nullptr if no
+ // metadata was present in the Model FlatBuffer provided at creation time.
+ const tflite::ModelMetadata* GetModelMetadata() const {
+ return model_metadata_;
+ }
+
+ // Gets the contents of the associated file with the provided name packed into
+ // the model metadata. An error is returned if there is no such associated
+ // file.
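+  //
+  // For example (illustrative), assuming a packed "labels.txt":
+  //   ASSIGN_OR_RETURN(absl::string_view labels,
+  //                    extractor->GetAssociatedFile("labels.txt"));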
+ tflite::support::StatusOr<absl::string_view> GetAssociatedFile(
+ const std::string& filename) const;
+
+  // Note: all methods below retrieve metadata of the *first* subgraph by
+  // default.
+
+ // Gets the metadata for input tensors.
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
+ GetInputTensorMetadata() const;
+
+ // Gets the metadata for the input tensor specified by the given index, or
+ // nullptr in case there is no metadata or the index is out of range.
+ const tflite::TensorMetadata* GetInputTensorMetadata(int index) const;
+
+ // Gets the count of input tensors with metadata in the metadata FlatBuffer.
+ // In particular, 0 is returned when there is no metadata.
+ int GetInputTensorCount() const;
+
+ // Gets the metadata for output tensors.
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
+ GetOutputTensorMetadata() const;
+
+ // Gets the metadata for the output tensor specified by the given index, or
+ // nullptr in case there is no metadata or the index is out of range.
+ const tflite::TensorMetadata* GetOutputTensorMetadata(int index) const;
+
+ // Gets the count of output tensors with metadata in the metadata FlatBuffer.
+ // In particular, 0 is returned when there is no metadata.
+ int GetOutputTensorCount() const;
+
+  // Gets the input process units from SubGraphMetadata.input_process_units;
+  // may be nullptr.
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
+ GetInputProcessUnits() const;
+
+ // Gets the input process unit specified by the given index, or nullptr in
+ // case there is no input process unit or the index is out of range.
+ const tflite::ProcessUnit* GetInputProcessUnit(int index) const;
+
+ // Gets the count of input process units. In particular, 0 is returned when
+  // there are no input process units.
+ int GetInputProcessUnitsCount() const;
+
+  // Gets the output process units from SubGraphMetadata.output_process_units;
+  // may be nullptr.
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
+ GetOutputProcessUnits() const;
+
+ // Gets the output process unit specified by the given index, or nullptr in
+ // case there is no output process unit or the index is out of range.
+ const tflite::ProcessUnit* GetOutputProcessUnit(int index) const;
+
+ // Gets the count of output process units. In particular, 0 is returned when
+  // there are no output process units.
+ int GetOutputProcessUnitsCount() const;
+
+ private:
+ static constexpr int kDefaultSubgraphIndex = 0;
+  // Private default constructor, called from CreateFromModelBuffer().
+ ModelMetadataExtractor() = default;
+ // Initializes the ModelMetadataExtractor from the provided Model FlatBuffer.
+ absl::Status InitFromModelBuffer(const char* buffer_data, size_t buffer_size);
+ // Extracts and stores in associated_files_ the associated files (if present)
+ // packed into the model FlatBuffer data.
+ absl::Status ExtractAssociatedFiles(const char* buffer_data,
+ size_t buffer_size);
+ // Pointer to the TFLite Model object from which to read the ModelMetadata.
+ const tflite::Model* model_{nullptr};
+ // Pointer to the extracted ModelMetadata, if any.
+ const tflite::ModelMetadata* model_metadata_{nullptr};
+ // The files associated with the ModelMetadata, as a map with the filename
+ // (corresponding to a basename, e.g. "labels.txt") as key and the file
+ // contents as value.
+ absl::flat_hash_map<std::string, std::string> associated_files_;
+};
+
+} // namespace metadata
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
diff --git a/tensorflow_lite_support/metadata/cc/metadata_parser.h.template b/tensorflow_lite_support/metadata/cc/metadata_parser.h.template
new file mode 100644
index 00000000..7e260508
--- /dev/null
+++ b/tensorflow_lite_support/metadata/cc/metadata_parser.h.template
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_PARSER_H_
+#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_PARSER_H_
+
+namespace tflite {
+namespace metadata {
+
+// The version of the metadata parser that this metadata versioning library
+// depends on.
+inline constexpr char kMetadataParserVersion[] =
+    "{LATEST_METADATA_PARSER_VERSION}";
+
+} // namespace metadata
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_PARSER_H_
diff --git a/tensorflow_lite_support/metadata/cc/metadata_version.cc b/tensorflow_lite_support/metadata/cc/metadata_version.cc
new file mode 100644
index 00000000..7679f6c4
--- /dev/null
+++ b/tensorflow_lite_support/metadata/cc/metadata_version.cc
@@ -0,0 +1,302 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/metadata/cc/metadata_version.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <array>
+#include <ostream>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/tools/logging.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace metadata {
+namespace {
+
+// Members that are added to the metadata schema after the initial version
+// of 1.0.0.
+enum class SchemaMembers {
+ kAssociatedFileTypeVocabulary = 0,
+ kSubGraphMetadataInputProcessUnits = 1,
+ kSubGraphMetadataOutputProcessUnits = 2,
+ kProcessUnitOptionsBertTokenizerOptions = 3,
+ kProcessUnitOptionsSentencePieceTokenizerOptions = 4,
+ kSubGraphMetadataInputTensorGroups = 5,
+ kSubGraphMetadataOutputTensorGroups = 6,
+ kProcessUnitOptionsRegexTokenizerOptions = 7,
+};
+
+// Helper class to compare semantic versions in terms of three integers, major,
+// minor, and patch.
+class Version {
+ public:
+ explicit Version(int major, int minor = 0, int patch = 0)
+ : version_({major, minor, patch}) {}
+
+ explicit Version(const std::string& version) {
+ const std::vector<std::string> vec = absl::StrSplit(version, '.');
+ // The version string should always be less than four numbers.
+ TFLITE_DCHECK(vec.size() <= kElementNumber && !vec.empty());
+ version_[0] = std::stoi(vec[0]);
+ version_[1] = vec.size() > 1 ? std::stoi(vec[1]) : 0;
+ version_[2] = vec.size() > 2 ? std::stoi(vec[2]) : 0;
+ }
+
+ // Compares two semantic version numbers.
+ //
+ // Example results when comparing two versions strings:
+ // "1.9" precedes "1.14";
+ // "1.14" precedes "1.14.1";
+ // "1.14" and "1.14.0" are equal.
+ //
+ // Returns the value 0 if the two versions are equal; a value less than 0 if
+ // *this precedes v; a value greater than 0 if v precedes *this.
+ int Compare(const Version& v) {
+ for (int i = 0; i < kElementNumber; ++i) {
+ if (version_[i] != v.version_[i]) {
+ return version_[i] < v.version_[i] ? -1 : 1;
+ }
+ }
+ return 0;
+ }
+
+ // Converts version_ into a version string.
+ std::string ToString() { return absl::StrJoin(version_, "."); }
+
+ private:
+ static constexpr int kElementNumber = 3;
+ std::array<int, kElementNumber> version_;
+};
+
+Version GetMemberVersion(SchemaMembers member) {
+ switch (member) {
+ case SchemaMembers::kAssociatedFileTypeVocabulary:
+ return Version(1, 0, 1);
+ case SchemaMembers::kSubGraphMetadataInputProcessUnits:
+ return Version(1, 1, 0);
+ case SchemaMembers::kSubGraphMetadataOutputProcessUnits:
+ return Version(1, 1, 0);
+ case SchemaMembers::kProcessUnitOptionsBertTokenizerOptions:
+ return Version(1, 1, 0);
+ case SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions:
+ return Version(1, 1, 0);
+ case SchemaMembers::kSubGraphMetadataInputTensorGroups:
+ return Version(1, 2, 0);
+ case SchemaMembers::kSubGraphMetadataOutputTensorGroups:
+ return Version(1, 2, 0);
+ case SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions:
+ return Version(1, 2, 1);
+ default:
+ // Should never happen.
+ TFLITE_LOG(FATAL) << "Unsupported schema member: "
+ << static_cast<int>(member);
+ }
+ // Should never happen.
+ return Version(0, 0, 0);
+}
+
+// Updates min_version if it precedes the new_version.
+inline void UpdateMinimumVersion(const Version& new_version,
+ Version* min_version) {
+ if (min_version->Compare(new_version) < 0) {
+ *min_version = new_version;
+ }
+}
+
+template <typename T>
+void UpdateMinimumVersionForTable(const T* table, Version* min_version);
+
+template <typename T>
+void UpdateMinimumVersionForArray(
+ const flatbuffers::Vector<flatbuffers::Offset<T>>* array,
+ Version* min_version) {
+ if (array == nullptr) return;
+
+ for (int i = 0; i < array->size(); ++i) {
+ UpdateMinimumVersionForTable<T>(array->Get(i), min_version);
+ }
+}
+
+template <>
+void UpdateMinimumVersionForTable<tflite::AssociatedFile>(
+ const tflite::AssociatedFile* table, Version* min_version) {
+ if (table == nullptr) return;
+
+ if (table->type() == AssociatedFileType_VOCABULARY) {
+ UpdateMinimumVersion(
+ GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary),
+ min_version);
+ }
+}
+
+template <>
+void UpdateMinimumVersionForTable<tflite::ProcessUnit>(
+ const tflite::ProcessUnit* table, Version* min_version) {
+ if (table == nullptr) return;
+
+ tflite::ProcessUnitOptions process_unit_type = table->options_type();
+ if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) {
+ UpdateMinimumVersion(
+ GetMemberVersion(
+ SchemaMembers::kProcessUnitOptionsBertTokenizerOptions),
+ min_version);
+ }
+ if (process_unit_type == ProcessUnitOptions_SentencePieceTokenizerOptions) {
+ UpdateMinimumVersion(
+ GetMemberVersion(
+ SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions),
+ min_version);
+ }
+ if (process_unit_type == ProcessUnitOptions_RegexTokenizerOptions) {
+ UpdateMinimumVersion(
+ GetMemberVersion(
+ SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions),
+ min_version);
+ }
+}
+
+template <>
+void UpdateMinimumVersionForTable<tflite::TensorMetadata>(
+ const tflite::TensorMetadata* table, Version* min_version) {
+ if (table == nullptr) return;
+
+ // Checks the associated_files field.
+ UpdateMinimumVersionForArray<tflite::AssociatedFile>(
+ table->associated_files(), min_version);
+
+ // Checks the process_units field.
+ UpdateMinimumVersionForArray<tflite::ProcessUnit>(table->process_units(),
+ min_version);
+}
+
+template <>
+void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
+ const tflite::SubGraphMetadata* table, Version* min_version) {
+ if (table == nullptr) return;
+
+ // Checks in the input/output metadata arrays.
+ UpdateMinimumVersionForArray<tflite::TensorMetadata>(
+ table->input_tensor_metadata(), min_version);
+ UpdateMinimumVersionForArray<tflite::TensorMetadata>(
+ table->output_tensor_metadata(), min_version);
+
+ // Checks the associated_files field.
+ UpdateMinimumVersionForArray<tflite::AssociatedFile>(
+ table->associated_files(), min_version);
+
+ // Checks for the input_process_units field.
+ if (table->input_process_units() != nullptr) {
+ UpdateMinimumVersion(
+ GetMemberVersion(SchemaMembers::kSubGraphMetadataInputProcessUnits),
+ min_version);
+ UpdateMinimumVersionForArray<tflite::ProcessUnit>(
+ table->input_process_units(), min_version);
+ }
+
+ // Checks for the output_process_units field.
+ if (table->output_process_units() != nullptr) {
+ UpdateMinimumVersion(
+ GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputProcessUnits),
+ min_version);
+ UpdateMinimumVersionForArray<tflite::ProcessUnit>(
+ table->output_process_units(), min_version);
+ }
+
+ // Checks for the input_tensor_groups field.
+ if (table->input_tensor_groups() != nullptr) {
+ UpdateMinimumVersion(
+ GetMemberVersion(SchemaMembers::kSubGraphMetadataInputTensorGroups),
+ min_version);
+ }
+
+ // Checks for the output_tensor_groups field.
+ if (table->output_tensor_groups() != nullptr) {
+ UpdateMinimumVersion(
+ GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups),
+ min_version);
+ }
+}
+
+template <>
+void UpdateMinimumVersionForTable<tflite::ModelMetadata>(
+ const tflite::ModelMetadata* table, Version* min_version) {
+ if (table == nullptr) {
+ // Should never happen, because VerifyModelMetadataBuffer has verified it.
+ TFLITE_LOG(FATAL) << "The ModelMetadata object is null.";
+ return;
+ }
+
+ // Checks the subgraph_metadata field.
+ if (table->subgraph_metadata() != nullptr) {
+ for (int i = 0; i < table->subgraph_metadata()->size(); ++i) {
+ UpdateMinimumVersionForTable<tflite::SubGraphMetadata>(
+ table->subgraph_metadata()->Get(i), min_version);
+ }
+ }
+
+ // Checks the associated_files field.
+ UpdateMinimumVersionForArray<tflite::AssociatedFile>(
+ table->associated_files(), min_version);
+}
+
+} // namespace
+
+TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
+ size_t buffer_size,
+ std::string* min_version_str) {
+ flatbuffers::Verifier verifier =
+ flatbuffers::Verifier(buffer_data, buffer_size);
+ if (!tflite::VerifyModelMetadataBuffer(verifier)) {
+ TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer.";
+ return kTfLiteError;
+ }
+
+ static constexpr char kDefaultVersion[] = "1.0.0";
+ Version min_version = Version(kDefaultVersion);
+
+ // Checks if any member declared after 1.0.0 (such as those in
+ // SchemaMembers) exists, and updates min_version accordingly. The minimum
+  // metadata parser version will be the largest version number of all fields
+  // that have been added to the metadata flatbuffer.
+ const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data);
+
+ // All tables in the metadata schema should have their dedicated
+ // UpdateMinimumVersionForTable<Foo>() methods, respectively. We'll gradually
+ // add these methods when new fields show up in later schema versions.
+ //
+ // UpdateMinimumVersionForTable<Foo>() takes a const pointer of Foo. The
+ // pointer can be a nullptr if Foo is not populated into the corresponding
+ // table of the Flatbuffer object. In this case,
+ // UpdateMinimumVersionFor<Foo>() will be skipped. An exception is
+ // UpdateMinimumVersionForModelMetadata(), where ModelMetadata is the root
+ // table, and it won't be null.
+ UpdateMinimumVersionForTable<tflite::ModelMetadata>(model_metadata,
+ &min_version);
+
+ *min_version_str = min_version.ToString();
+ return kTfLiteOk;
+}
+
+} // namespace metadata
+} // namespace tflite
diff --git a/tensorflow_lite_support/metadata/cc/metadata_version.h b/tensorflow_lite_support/metadata/cc/metadata_version.h
new file mode 100644
index 00000000..6332aaec
--- /dev/null
+++ b/tensorflow_lite_support/metadata/cc/metadata_version.h
@@ -0,0 +1,38 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_VERSION_H_
+#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_VERSION_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+namespace metadata {
+
+// Gets the minimum metadata parser version that can fully understand all fields
+// in a given metadata flatbuffer. TFLite Metadata follows Semantic Versioning
+// 2.0. Each release version has the form MAJOR.MINOR.PATCH.
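+//
+// A minimal usage sketch (illustrative):
+//   std::string min_version;
+//   if (GetMinimumMetadataParserVersion(buffer, size, &min_version) ==
+//       kTfLiteOk) {
+//     // min_version now holds, e.g., "1.2.1".
+//   }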
+TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data,
+ size_t buffer_size,
+ std::string* min_version);
+
+} // namespace metadata
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_VERSION_H_
diff --git a/tensorflow_lite_support/metadata/cc/python/BUILD b/tensorflow_lite_support/metadata/cc/python/BUILD
new file mode 100644
index 00000000..34e9a4f9
--- /dev/null
+++ b/tensorflow_lite_support/metadata/cc/python/BUILD
@@ -0,0 +1,22 @@
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+
+package(
+ default_visibility = [
+ "//tensorflow_lite_support/metadata:__subpackages__",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+pybind_extension(
+ name = "_pywrap_metadata_version",
+ srcs = [
+ "metadata_version.cc",
+ ],
+ features = ["-use_header_modules"],
+ module_name = "_pywrap_metadata_version",
+ deps = [
+ "//tensorflow_lite_support/metadata/cc:metadata_version",
+ "@org_tensorflow//tensorflow/lite/c:common",
+ "@pybind11",
+ ],
+)
diff --git a/tensorflow_lite_support/metadata/cc/python/metadata_version.cc b/tensorflow_lite_support/metadata/cc/python/metadata_version.cc
new file mode 100644
index 00000000..db3a29e5
--- /dev/null
+++ b/tensorflow_lite_support/metadata/cc/python/metadata_version.cc
@@ -0,0 +1,55 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/metadata/cc/metadata_version.h"
+
+#include "pybind11/pybind11.h"
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+namespace metadata {
+
+PYBIND11_MODULE(_pywrap_metadata_version, m) {
+ m.doc() = R"pbdoc(
+ _pywrap_metadata_version
+ A module that returns the minimum metadata parser version of a given
+ metadata flatbuffer.
+ )pbdoc";
+
+  // Using pybind11 type conversions to convert between Python and native
+  // C++ types. There are other options for providing access to native Python
+  // types in C++ and vice versa; see the pybind11 documentation [1] for more
+  // details. Type conversion is recommended by pybind11, though the main
+  // downside is that a copy of the data must be made on every Python-to-C++
+  // transition: this is needed since the C++ and Python versions of the same
+  // type generally won't have the same memory layout.
+ //
+ // [1]: https://pybind11.readthedocs.io/en/stable/advanced/cast/index.html
+ m.def("GetMinimumMetadataParserVersion",
+ [](const std::string& buffer_data) -> std::string {
+ std::string min_version;
+ if (GetMinimumMetadataParserVersion(
+ reinterpret_cast<const uint8_t*>(buffer_data.c_str()),
+ buffer_data.length(), &min_version) != kTfLiteOk) {
+            throw pybind11::value_error(
+ "Error occurred when getting the minimum metadata parser "
+ "version of the metadata flatbuffer.");
+ }
+ return min_version;
+ });
+}
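+
+// A Python usage sketch (illustrative; the import path depends on how the
+// extension module is packaged):
+//   from _pywrap_metadata_version import GetMinimumMetadataParserVersion
+//   min_version = GetMinimumMetadataParserVersion(metadata_buffer)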
+
+} // namespace metadata
+} // namespace tflite
diff --git a/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD b/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD
new file mode 100644
index 00000000..d4171bf9
--- /dev/null
+++ b/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD
@@ -0,0 +1,22 @@
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
+
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ licenses = ["notice"], # Apache 2.0
+)
+
+pybind_extension(
+ name = "_pywrap_flatbuffers",
+ srcs = [
+ "flatbuffers_lib.cc",
+ ],
+ features = ["-use_header_modules"],
+ module_name = "_pywrap_flatbuffers",
+ deps = [
+ "@flatbuffers",
+ "@local_config_python//:python_headers",
+ "@pybind11",
+ ],
+)
diff --git a/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc b/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
new file mode 100644
index 00000000..61857225
--- /dev/null
+++ b/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc
@@ -0,0 +1,59 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "flatbuffers/idl.h" // from @flatbuffers
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+
+namespace tflite {
+namespace support {
+
+PYBIND11_MODULE(_pywrap_flatbuffers, m) {
+ pybind11::class_<flatbuffers::IDLOptions>(m, "IDLOptions")
+ .def(pybind11::init<>())
+ .def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json);
+ pybind11::class_<flatbuffers::Parser>(m, "Parser")
+ .def(pybind11::init<const flatbuffers::IDLOptions&>())
+ .def("parse",
+ [](flatbuffers::Parser* self, const std::string& source) {
+ return self->Parse(source.c_str());
+ })
+ .def_readonly("builder", &flatbuffers::Parser::builder_)
+ .def_readonly("error", &flatbuffers::Parser::error_);
+ pybind11::class_<flatbuffers::FlatBufferBuilder>(m, "FlatBufferBuilder")
+ .def("clear", &flatbuffers::FlatBufferBuilder::Clear)
+ .def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self,
+ const std::string& contents) {
+ self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
+ contents.length());
+ });
+ m.def("generate_text_file", &flatbuffers::GenerateTextFile);
+ m.def(
+ "generate_text",
+ [](const flatbuffers::Parser& parser,
+ const std::string& buffer) -> std::string {
+ std::string text;
+ if (!flatbuffers::GenerateText(
+ parser, reinterpret_cast<const void*>(buffer.c_str()), &text)) {
+ return "";
+ }
+ return text;
+ });
+}
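+
+// A Python usage sketch (illustrative; the import path depends on packaging):
+//   opts = _pywrap_flatbuffers.IDLOptions()
+//   opts.strict_json = True
+//   parser = _pywrap_flatbuffers.Parser(opts)
+//   assert parser.parse(schema_source)
+//   json_text = _pywrap_flatbuffers.generate_text(parser, flatbuffer_bytes)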
+
+} // namespace support
+} // namespace tflite
diff --git a/tensorflow_lite_support/metadata/metadata_schema.fbs b/tensorflow_lite_support/metadata/metadata_schema.fbs
new file mode 100644
index 00000000..8faae0a8
--- /dev/null
+++ b/tensorflow_lite_support/metadata/metadata_schema.fbs
@@ -0,0 +1,686 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace tflite;
+
+// TFLite metadata contains both human readable and machine readable information
+// about what the model does and how to use the model. It can be used as a
+// README file, which elaborates the details of the model, each input/output
+// tensor, and each associated file.
+//
+// An important use case of TFLite metadata is the TFLite codegen tool, which
+// automatically generates the model interface based on the properties of the
+// model and the tensors. The model interface provides high-level APIs to
+// interact with the model, such as preprocessing the input data and running
+// inferences.
+//
+// Entries marked with "<Codegen usage>" are used by the TFLite codegen tool to
+// generate the model interface. It is recommended to fill in at least those
+// entries to boost the codegen performance.
+
+// The Metadata schema is versioned by the Semantic versioning number, such as
+// MAJOR.MINOR.PATCH. It tracks the schema changes according to the rules below:
+// * Bump up the MAJOR number when making potentially backwards incompatible
+// changes. It must be incremented if the new changes break the backwards
+// compatibility. It may also include minor and patch level changes as
+// needed. The true backwards compatibility is indicated by the file
+// identifier.
+// * Bump up the MINOR number when making backwards compatible updates for
+// major features, such as supporting new content types or adding new
+// processing units.
+// * Bump up the PATCH number when making small backwards compatible changes,
+//   such as adding new fields or deprecating certain fields (not deleting
+// them).
+//
+// ModelMetadata.min_parser_version indicates the minimum necessary metadata
+// parser version to fully understand all fields in a given metadata flatbuffer.
+//
+// New fields and types will have associated comments with the schema version
+// for which they were added.
+//
+// LINT.IfChange
+// Schema Semantic version: 1.2.1
+// LINT.ThenChange(//tensorflow_lite_support/\
+// metadata/java/src/java/org/tensorflow/lite/support/metadata/\
+// MetadataParser.java)
+
+// This indicates the flatbuffer compatibility. The number will bump up when a
+// breaking change is applied to the schema, such as removing fields or adding
+// new fields to the middle of a table.
+file_identifier "M001";
+
+// History:
+// 1.0.1 - Added VOCABULARY type to AssociatedFileType.
+// 1.1.0 - Added BertTokenizerOptions to ProcessUnitOptions.
+// Added SentencePieceTokenizerOptions to ProcessUnitOptions.
+// Added input_process_units to SubGraphMetadata.
+// Added output_process_units to SubGraphMetadata.
+// 1.2.0 - Added input_tensor_group to SubGraphMetadata.
+// Added output_tensor_group to SubGraphMetadata.
+// 1.2.1 - Added RegexTokenizerOptions to ProcessUnitOptions.
+
+// File extension of any written files.
+file_extension "tflitemeta";
+
+// LINT.IfChange
+enum AssociatedFileType : byte {
+ UNKNOWN = 0,
+
+ // Files such as readme.txt.
+ DESCRIPTIONS = 1,
+
+  // Contains labels that annotate a certain axis of the tensor. For example,
+  // the label file in image classification. Those labels annotate the
+  // output tensor, such that each value in the output tensor is the
+  // probability of the corresponding category specified by the label.
+ //
+ // <Codegen usage>:
+ // If an output tensor has an associated file as TENSOR_AXIS_LABELS, return
+ // the output as a mapping between the labels and probability in the model
+ // interface.
+ // If multiple files of the same type are present, the first one is used by
+ // default; additional ones are to be distinguished from one another by their
+ // specified locale.
+ TENSOR_AXIS_LABELS = 2,
+
+ // Contains labels that tensor values correspond to. For example, in
+ // the object detection model, one of the output tensors is the detected
+ // classes. And each value in the tensor refers to the index of label in the
+ // category label file.
+ //
+ // <Codegen usage>:
+ // If an output tensor has an associated file as TENSOR_VALUE_LABELS, convert
+ // the tensor values into labels, and return a list of string as the output.
+ // If multiple files of the same type are present, the first one is used by
+ // default; additional ones are to be distinguished from one another by their
+ // specified locale.
+ TENSOR_VALUE_LABELS = 3,
+
+ // Contains sigmoid-based score calibration parameters, formatted as CSV.
+ // Lines contain for each index of an output tensor the scale, slope, offset
+ // and (optional) min_score parameters to be used for sigmoid fitting (in this
+ // order and in `strtof`-compatible [1] format).
+ // A line may be left empty to default calibrated scores for this index to
+ // default_score.
+ // In summary, each line should thus contain 0, 3 or 4 comma-separated values.
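+  // For example (illustrative values), the line "0.2,0.98,0.1" provides
+  // scale, slope and offset, and leaves min_score unset.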
+ //
+ // See documentation for ScoreCalibrationOptions for details.
+ //
+ // [1]: https://en.cppreference.com/w/c/string/byte/strtof
+ TENSOR_AXIS_SCORE_CALIBRATION = 4,
+
+ // Contains a list of unique words (characters separated by "\n" or in lines)
+ // that help to convert natural language words to embedding vectors.
+ // Added in: 1.0.1
+ VOCABULARY = 5,
+}
+
+table AssociatedFile {
+  // Name of this file. Needs to be exactly the same as the name of the actual
+  // file packed into the TFLite model as a zip file.
+ //
+ // <Codegen usage>:
+  // Locates the actual file in the TFLite model.
+ name:string;
+
+ // A description of what the file is.
+ description:string;
+
+ // Type of the associated file. There may be special pre/post processing for
+ // some types. For example in image classification, a label file of the output
+ // will be used to convert object index into string.
+ //
+ // <Codegen usage>:
+ // Determines how to process the corresponding tensor.
+ type:AssociatedFileType;
+
+ // An optional locale for this associated file (if applicable). It is
+ // recommended to use an ISO 639-1 letter code (e.g. "en" for English),
+ // optionally completed by a two letter region code (e.g. "en-US" for US
+ // English and "en-CA" for Canadian English).
+  // Leverage this in order to specify e.g. multiple label files translated
+  // into different languages.
+ locale:string;
+}
+
+// The basic content type for all tensors.
+//
+// <Codegen usage>:
+// Input feature tensors:
+// 1. Generates the method to load data from a TensorBuffer.
+// 2. Creates the preprocessing logic. The default processing pipeline is:
+// [NormalizeOp, QuantizeOp].
+// Output feature tensors:
+// 1. Generates the method to return the output data to a TensorBuffer.
+// 2. Creates the post-processing logic. The default processing pipeline is:
+// [DeQuantizeOp].
+table FeatureProperties {
+}
+
+// The type of color space of an image.
+enum ColorSpaceType : byte {
+ UNKNOWN = 0,
+ RGB = 1,
+ GRAYSCALE = 2,
+}
+
+table ImageSize {
+ width:uint;
+ height:uint;
+}
+
+// The properties for image tensors.
+//
+// <Codegen usage>:
+// Input image tensors:
+// 1. Generates the method to load an image from a TensorImage.
+// 2. Creates the preprocessing logic. The default processing pipeline is:
+// [ResizeOp, NormalizeOp, QuantizeOp].
+// Output image tensors:
+// 1. Generates the method to return the output data to a TensorImage.
+// 2. Creates the post-processing logic. The default processing pipeline is:
+// [DeQuantizeOp].
+table ImageProperties {
+ // The color space of the image.
+ //
+ // <Codegen usage>:
+ // Determines how to convert the color space of a given image from users.
+ color_space:ColorSpaceType;
+
+ // Indicates the default value of image width and height if the tensor shape
+ // is dynamic. For fixed-size tensor, this size will be consistent with the
+ // expected size.
+ default_size:ImageSize;
+}
+
+// The properties for tensors representing bounding boxes.
+//
+// <Codegen usage>:
+// Input image tensors: NA.
+// Output image tensors: parses the values into a data structure that represents
+// bounding boxes. For example, in the generated wrapper for Android, it returns
+// the output as android.graphics.Rect objects.
+enum BoundingBoxType : byte {
+ UNKNOWN = 0,
+ // Represents the bounding box by using the combination of boundaries,
+ // {left, top, right, bottom}.
+ // The default order is {left, top, right, bottom}. Other orders can be
+ // indicated by BoundingBoxProperties.index.
+ BOUNDARIES = 1,
+
+ // Represents the bounding box by using the upper_left corner, width and
+ // height.
+ // The default order is {upper_left_x, upper_left_y, width, height}. Other
+ // orders can be indicated by BoundingBoxProperties.index.
+ UPPER_LEFT = 2,
+
+ // Represents the bounding box by using the center of the box, width and
+ // height. The default order is {center_x, center_y, width, height}. Other
+ // orders can be indicated by BoundingBoxProperties.index.
+ CENTER = 3,
+
+}
+
+enum CoordinateType : byte {
+ // The coordinates are float values from 0 to 1.
+ RATIO = 0,
+ // The coordinates are integers.
+ PIXEL = 1,
+}
+
+table BoundingBoxProperties {
+ // Denotes the order of the elements defined in each bounding box type. An
+  // empty index array represents the default order of each bounding box type.
+ // For example, to denote the default order of BOUNDARIES, {left, top, right,
+ // bottom}, the index should be {0, 1, 2, 3}. To denote the order {left,
+ // right, top, bottom}, the order should be {0, 2, 1, 3}.
+ //
+ // The index array can be applied to all bounding box types to adjust the
+ // order of their corresponding underlying elements.
+ //
+ // <Codegen usage>:
+ // Indicates how to parse the bounding box values.
+ index:[uint];
+
+ // <Codegen usage>:
+ // Indicates how to parse the bounding box values.
+ type:BoundingBoxType;
+
+ // <Codegen usage>:
+ // Indicates how to convert the bounding box back to the original image in
+ // pixels.
+ coordinate_type:CoordinateType;
+}
+
+union ContentProperties {
+ FeatureProperties,
+ ImageProperties,
+ BoundingBoxProperties,
+}
+
+table ValueRange {
+ min:int;
+ max:int;
+}
+
+table Content {
+ // The properties that the content may have, indicating the type of the
+ // Content.
+ //
+ // <Codegen usage>:
+ // Indicates how to process the tensor.
+ content_properties:ContentProperties;
+
+ // The range of dimensions that the content corresponds to. A NULL
+ // "range" indicates that the content uses up all dimensions,
+  // except the batch axis, if present.
+ //
+ // Here are all the possible situations of how a tensor is composed.
+ // Case 1: The tensor is a single object, such as an image.
+ // For example, the input of an image classifier
+  // (https://www.tensorflow.org/lite/models/image_classification/overview)
+  // is a tensor of shape [1, 224, 224, 3]. Dimensions 1 to 3 correspond to the
+ // image. Since dimension 0 is a batch axis, which can be ignored,
+ // "range" can be left as NULL.
+ //
+ // Case 2: The tensor contains multiple instances of the same object.
+ // For example, the output tensor of detected bounding boxes of an object
+ // detection model
+ // (https://www.tensorflow.org/lite/models/object_detection/overview).
+  // The tensor shape is [1, 10, 4]. Here is what the three dimensions
+  // represent:
+ // dimension 0: the batch axis.
+ // dimension 1: the 10 objects detected with the highest confidence.
+ // dimension 2: the bounding boxes of the 10 detected objects.
+ // The tensor is essentially 10 bounding boxes. In this case,
+ // "range" should be {min=2; max=2;}.
+ //
+ // The output tensor of scores of the above object detection model has shape
+ // [1, 10], where
+ // dimension 0: the batch axis;
+ // dimension 1: the scores of the 10 detected objects.
+ // Set "range" to the number of dimensions which is {min=2; max=2;} to denote
+ // that every element in the tensor is an individual content object, i.e. a
+ // score in this example.
+ //
+ // Another example is the pose estimation model
+ // (https://www.tensorflow.org/lite/models/pose_estimation/overview).
+ // The output tensor of heatmaps is in the shape of [1, 9, 9, 17].
+  // Here is what the four dimensions represent:
+ // dimension 0: the batch axis.
+ // dimension 1/2: the heatmap image.
+ // dimension 3: 17 body parts of a person.
+  // Even though the last axis is the body part, the real content of this
+  // tensor is the heatmap, so "range" should be {min=1; max=2;}.
+ //
+ // Case 3: The tensor contains multiple different objects. (Not supported by
+ // Content at this point).
+ // Sometimes a tensor may contain multiple different objects, thus different
+ // contents. It is very common for regression models. For example, a model
+ // to predict the fuel efficiency
+ // (https://www.tensorflow.org/tutorials/keras/regression).
+ // The input tensor has shape [1, 9], consisting of 9 features, such as
+ // "Cylinders", "Displacement", "Weight", etc. In this case, dimension 1
+  // contains 9 different contents. However, since these sub-dimension objects
+  // rarely need to be processed individually, their contents are not recorded
+  // in the metadata. The name of each dimension can still be set through
+  // TensorMetadata.dimension_names.
+ //
+  // Note that, except in Case 3, a tensor can have only one content type.
+ //
+ // <Codegen usage>:
+  // Case 1: return a single processed object of a certain content type.
+  // Case 2: return a list of processed objects of a certain content type. The
+  // generated model interface has an API to randomly access those objects in
+  // the output.
+ range:ValueRange;
+}
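+
+// Editor's sketch, not part of the schema: for Case 2 above, a [1, 10, 4]
+// detection tensor with range = {min=2; max=2;} can be viewed as 10 content
+// objects, each spanning dimension 2. In Python with numpy (an assumption,
+// not a dependency of this schema):
+//
+//   import numpy as np
+//
+//   boxes = np.zeros((1, 10, 4))  # [batch, detected objects, box values]
+//   # Dimensions outside [min, max] (ignoring the batch axis) enumerate the
+//   # objects; dimensions within [min, max] make up one object.
+//   per_object = [boxes[0, i, :] for i in range(boxes.shape[1])]
+//   assert len(per_object) == 10 and per_object[0].shape == (4,)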
+
+// Parameters that are used when normalizing the tensor.
+table NormalizationOptions {
+ // mean and std are normalization parameters. Tensor values are normalized
+ // on a per-channel basis, by the formula
+ // (x - mean) / std.
+  // If there is only one value in mean or std, we'll propagate the value to
+ // all channels.
+ //
+ // Quantized models share the same normalization parameters as their
+ // corresponding float models. For example, an image input tensor may have
+ // the normalization parameter of
+ // mean = 127.5f and std = 127.5f.
+ // The image value will be normalized from [0, 255] to [-1, 1].
+ // Then, for quantized models, the image data should be further quantized
+ // according to the quantization parameters. In the case of uint8, the image
+ // data will be scaled back to [0, 255], while for int8, the image data will
+ // be scaled to [-128, 127].
+ //
+ // Both the normalization parameters and quantization parameters can be
+ // retrieved through the metadata extractor library.
+ // TODO(b/156644598): add link for the metadata extractor library.
+
+ // Per-channel mean of the possible values used in normalization.
+ //
+ // <Codegen usage>:
+ // Apply normalization to input tensors accordingly.
+ mean:[float];
+
+ // Per-channel standard dev. of the possible values used in normalization.
+ //
+ // <Codegen usage>:
+ // Apply normalization to input tensors accordingly.
+ std:[float];
+}
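+
+// Editor's numeric sketch, not part of the schema: with mean = 127.5 and
+// std = 127.5, and hypothetical uint8 quantization parameters
+// scale = 1/128 and zero_point = 128 read from the model:
+//
+//   def normalize(x, mean=127.5, std=127.5):
+//       return (x - mean) / std  # maps [0, 255] to [-1, 1]
+//
+//   def quantize(y, scale=1.0 / 128, zero_point=128):
+//       # Re-quantizes the normalized value for a uint8 model input.
+//       return max(0, min(255, round(y / scale) + zero_point))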
+
+// The different possible score transforms to apply to uncalibrated scores
+// before applying score calibration.
+enum ScoreTransformationType : byte {
+ // Identity function: g(x) = x.
+ IDENTITY = 0,
+ // Log function: g(x) = log(x).
+ LOG = 1,
+ // Inverse logistic function: g(x) = log(x) - log(1-x).
+ INVERSE_LOGISTIC = 2,
+}
+
+// Options to perform score calibration on an output tensor through sigmoid
+// functions. One of the main purposes of score calibration is to make scores
+// across classes comparable, so that a common threshold can be used for all
+// output classes. This is meant for models producing class predictions as
+// output, e.g. image classification or detection models.
+//
+// For each index in the output tensor, this applies:
+// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score` or if no
+// `min_score` has been specified,
+// * `f(x) = default_score` otherwise or if no scale, slope and offset have been
+// specified.
+// Where:
+// * scale, slope, offset and (optional) min_score are index-specific parameters
+// * g(x) is an index-independent transform among those defined in
+// ScoreTransformationType
+// * default_score is an index-independent parameter.
+// An AssociatedFile with type TENSOR_AXIS_SCORE_CALIBRATION specifying the
+// index-specific parameters must be associated with the corresponding
+// TensorMetadata for score calibration to be applied.
+table ScoreCalibrationOptions {
+ // The function to use for transforming the uncalibrated score before
+ // applying score calibration.
+ score_transformation:ScoreTransformationType;
+
+ // The default calibrated score to apply if the uncalibrated score is
+ // below min_score or if no parameters were specified for a given index.
+ default_score:float;
+}
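+
+// Editor's sketch, not part of the schema, of the formula above with g(x)
+// chosen as IDENTITY:
+//
+//   import math
+//
+//   def calibrate(x, scale, slope, offset, default_score, min_score=None):
+//       g = lambda v: v  # IDENTITY; use math.log for LOG, etc.
+//       if min_score is not None and x <= min_score:
+//           return default_score
+//       return scale / (1 + math.exp(-(slope * g(x) + offset)))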
+
+// Performs thresholding on output tensor values, in order to filter out
+// low-confidence results.
+table ScoreThresholdingOptions {
+ // The recommended global threshold below which results are considered
+ // low-confidence and should be filtered out.
+ global_score_threshold:float;
+}
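+
+// Editor's sketch, not part of the schema: a consumer could apply the
+// threshold as
+//
+//   kept = [r for r in results if r.score >= global_score_threshold]
+//
+// where `results` and `r.score` are hypothetical names.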
+
+// Performs Bert tokenization as in tf.text.BertTokenizer
+// (https://github.com/tensorflow/text/blob/3599f6fcd2b780a2dc413b90fb9315464f10b314/docs/api_docs/python/text/BertTokenizer.md)
+// Added in: 1.1.0
+table BertTokenizerOptions {
+ // The vocabulary files used in the BertTokenizer.
+ vocab_file:[AssociatedFile];
+}
+
+// Performs SentencePiece tokenization as in tf.text.SentencepieceTokenizer
+// (https://github.com/tensorflow/text/blob/3599f6fcd2b780a2dc413b90fb9315464f10b314/docs/api_docs/python/text/SentencepieceTokenizer.md).
+// Added in: 1.1.0
+table SentencePieceTokenizerOptions {
+ // The SentencePiece model files used in the SentencePieceTokenizer.
+ sentencePiece_model:[AssociatedFile];
+
+ // The optional vocabulary model files used in the SentencePieceTokenizer.
+ vocab_file:[AssociatedFile];
+}
+
+// Splits strings by the occurrences of delim_regex_pattern and converts the
+// tokens into ids. For example, given
+// delim_regex_pattern: "\W+",
+// string: "Words, words, words.",
+// the tokens after split are: "Words", "words", "words", "".
+// The tokens can then be converted into ids according to the vocab_file.
+// Added in: 1.2.1
+table RegexTokenizerOptions {
+ delim_regex_pattern:string;
+  // The vocabulary files used to convert the tokens into ids.
+ vocab_file:[AssociatedFile];
+}
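+
+// Editor's sketch, not part of the schema: the split step above matches
+// Python's re.split:
+//
+//   import re
+//   tokens = re.split(r"\W+", "Words, words, words.")
+//   # tokens == ['Words', 'words', 'words', '']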
+
+// Options that are used when processing the tensor.
+union ProcessUnitOptions {
+ NormalizationOptions,
+ ScoreCalibrationOptions,
+ ScoreThresholdingOptions,
+ // Added in: 1.1.0
+ BertTokenizerOptions,
+ // Added in: 1.1.0
+ SentencePieceTokenizerOptions,
+ // Added in: 1.2.1
+ RegexTokenizerOptions
+}
+
+// A process unit that is used to process the tensor out-of-graph.
+table ProcessUnit {
+ options:ProcessUnitOptions;
+}
+
+// Statistics to describe a tensor.
+table Stats {
+ // Max and min are not currently used in tflite.support codegen. They mainly
+ // serve as references for users to better understand the model. They can also
+ // be used to validate model pre/post processing results.
+  // If there is only one value in max or min, we'll propagate the value to
+ // all channels.
+
+ // Per-channel maximum value of the tensor.
+ max:[float];
+
+ // Per-channel minimum value of the tensor.
+ min:[float];
+}
+
+// Metadata of a group of tensors. It may contain several tensors that will be
+// grouped together in codegen. For example, the TFLite object detection model
+// example (https://www.tensorflow.org/lite/models/object_detection/overview)
+// has four outputs: classes, scores, bounding boxes, and number of detections.
+// If the four outputs are bundled together using TensorGroup (for example,
+// named as "detection result"), the codegen tool will generate the class,
+// `DetectionResult`, which contains the class, score, and bouding box. And the
+// outputs of the model will be converted to a list of `DetectionResults` and
+// the number of detection. Note that the number of detection is a single
+// number, therefore is inappropriate for the list of `DetectionResult`.
+// Added in: 1.2.0
+table TensorGroup {
+ // Name of tensor group.
+ //
+ // <codegen usage>:
+ // Name of the joint class of the tensor group.
+ name:string;
+
+ // Names of the tensors to group together, corresponding to
+ // TensorMetadata.name.
+ //
+ // <codegen usage>:
+ // Determines which tensors will be added to this group. All tensors in the
+ // group should have the same number of elements specified by Content.range.
+ tensor_names:[string];
+}
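+
+// Editor's sketch, not part of the schema: grouping the per-object elements
+// of hypothetical classes/scores/boxes output tensors, much like the
+// generated `DetectionResult` class described above:
+//
+//   detection_results = [
+//       {"class": c, "score": s, "bounding_box": b}
+//       for c, s, b in zip(classes, scores, boxes)
+//   ]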
+
+// Detailed information of an input or output tensor.
+table TensorMetadata {
+ // Name of the tensor.
+ //
+ // <Codegen usage>:
+ // The name of this tensor in the generated model interface.
+ name:string;
+
+ // A description of the tensor.
+ description:string;
+
+ // A list of names of the dimensions in this tensor. The length of
+  // dimension_names needs to match the number of dimensions in this tensor.
+ //
+ // <Codegen usage>:
+ // The name of each dimension in the generated model interface. See "Case 2"
+ // in the comments of Content.range.
+ dimension_names:[string];
+
+ // The content that represents this tensor.
+ //
+ // <Codegen usage>:
+ // Determines how to process this tensor. See each item in ContentProperties
+ // for the default process units that will be applied to the tensor.
+ content:Content;
+
+ // The process units that are used to process the tensor out-of-graph.
+ //
+ // <Codegen usage>:
+ // Contains the parameters of the default processing pipeline for each content
+ // type, such as the normalization parameters in all content types. See the
+ // items under ContentProperties for the details of the default processing
+ // pipeline.
+ process_units:[ProcessUnit];
+
+ // The statistics of the tensor values.
+ stats:Stats;
+
+ // A list of associated files of this tensor.
+ //
+ // <Codegen usage>:
+ // Contains processing parameters of this tensor, such as normalization.
+ associated_files:[AssociatedFile];
+}
+
+table SubGraphMetadata {
+ // Name of the subgraph.
+ //
+  // Note that, since TFLite only supports one subgraph at the moment, the
+ // Codegen tool will use the name in ModelMetadata in the generated model
+ // interface.
+ name:string;
+
+  // A description that explains details about what the subgraph does.
+ description:string;
+
+  // Metadata of all input tensors used in this subgraph. It matches exactly
+  // the input tensors specified by `SubGraph.inputs` in the TFLite
+  // schema.fbs file[2]. The number of `TensorMetadata` in the array should
+  // equal the number of indices in `SubGraph.inputs`.
+ //
+ // [2]: tensorflow/lite/schema/schema.fbs
+ // <Codegen usage>:
+ // Determines how to process the inputs.
+ input_tensor_metadata:[TensorMetadata];
+
+  // Metadata of all output tensors used in this subgraph. It matches exactly
+  // the output tensors specified by `SubGraph.outputs` in the TFLite
+  // schema.fbs file[2]. The number of `TensorMetadata` in the array should
+  // equal the number of indices in `SubGraph.outputs`.
+ //
+ // <Codegen usage>:
+ // Determines how to process the outputs.
+ output_tensor_metadata:[TensorMetadata];
+
+ // A list of associated files of this subgraph.
+ associated_files:[AssociatedFile];
+
+  // Input process units of the subgraph. Some models may have complex pre- and
+  // post-processing logic where the process units do not operate on one tensor
+  // at a time, but rather in a way similar to a TFLite graph. For example, in the
+ // MobileBert model (https://www.tensorflow.org/lite/models/bert_qa/overview),
+ // the inputs are: ids / mask / segment ids;
+ // the outputs are: end logits / start logits.
+ // The preprocessing converts the query string and the context string to the
+ // model inputs, and the post-processing converts the model outputs to the
+ // answer string.
+ // Added in: 1.1.0
+ input_process_units:[ProcessUnit];
+
+ // Output process units of the subgraph.
+ // Added in: 1.1.0
+ output_process_units:[ProcessUnit];
+
+ // Metadata of all input tensor groups used in this subgraph.
+ //
+ // <codegen usage>:
+ // Bundles the corresponding elements of the underlying input tensors together
+ // into a class, and converts those individual tensors into a list of the
+ // class objects.
+ // Added in: 1.2.0
+ input_tensor_groups:[TensorGroup];
+
+ // Metadata of all output tensor groups used in this subgraph.
+ //
+ // <codegen usage>:
+ // Bundles the corresponding elements of the underlying output tensors
+ // together into a class, and converts those individual tensors into a list of
+ // the class objects.
+ // Added in: 1.2.0
+ output_tensor_groups:[TensorGroup];
+}
+
+table ModelMetadata {
+ // Name of the model.
+ //
+ // <Codegen usage>:
+ // The name of the model in the generated model interface.
+ name:string;
+
+ // Model description in schema.
+ description:string;
+
+  // Version of the model, as specified by the model creators.
+ version:string;
+
+  // Note that the minimum required TFLite runtime version that the model is
+  // compatible with has already been added as a metadata entry in the TFLite
+  // schema. We'll decide later whether to move it here and keep it with the
+  // other metadata entries.
+
+ // Metadata of all the subgraphs of the model. The 0th is assumed to be the
+ // main subgraph.
+ //
+ // <Codegen usage>:
+ // Determines how to process the inputs and outputs.
+ subgraph_metadata:[SubGraphMetadata];
+
+  // The person who created this model.
+ author:string;
+
+ // Licenses that may apply to this model.
+ license:string;
+
+ // A list of associated files of this model.
+ associated_files:[AssociatedFile];
+
+ // The minimum metadata parser version that can fully understand the fields in
+ // the metadata flatbuffer. The version is effectively the largest version
+ // number among the versions of all the fields populated and the smallest
+ // compatible version indicated by the file identifier.
+ //
+  // This field is automatically populated by the MetadataPopulator when
+ // the metadata is populated into a TFLite model.
+ min_parser_version:string;
+}
+// LINT.ThenChange(//tensorflow_lite_support/\
+// metadata/cc/metadata_version.cc)
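+
+// Editor's sketch, not part of the schema: min_parser_version is effectively
+// a maximum over semantic versions, e.g.
+//
+//   def min_parser_version(versions):
+//       # versions: versions of all populated fields plus the smallest
+//       # version compatible with the file identifier,
+//       # e.g. ["1.0.0", "1.2.1", "1.0.0"] -> "1.2.1".
+//       return max(versions, key=lambda v: tuple(map(int, v.split("."))))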
+
+root_type ModelMetadata;
diff --git a/tensorflow_lite_support/opensource/opensource_only.files b/tensorflow_lite_support/opensource/opensource_only.files
new file mode 100644
index 00000000..be426420
--- /dev/null
+++ b/tensorflow_lite_support/opensource/opensource_only.files
@@ -0,0 +1,36 @@
+tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl
+tensorflow_lite_support/opensource/BUILD
+tensorflow_lite_support/opensource/WORKSPACE
+tensorflow_lite_support/opensource/cc_build_defs.bzl
+tensorflow_lite_support/third_party/android/BUILD
+tensorflow_lite_support/third_party/android/android.bzl.tpl
+tensorflow_lite_support/third_party/android/android_configure.BUILD.tpl
+tensorflow_lite_support/third_party/android/android_configure.bzl
+tensorflow_lite_support/third_party/com_google_absl.BUILD
+tensorflow_lite_support/third_party/darts_clone.BUILD
+tensorflow_lite_support/third_party/fft2d/BUILD
+tensorflow_lite_support/third_party/fft2d/LICENSE
+tensorflow_lite_support/third_party/fft2d/fft.h
+tensorflow_lite_support/third_party/fft2d/fft2d.BUILD
+tensorflow_lite_support/third_party/fft2d/fft2d.h
+tensorflow_lite_support/third_party/google_toolbox_for_mac.BUILD
+tensorflow_lite_support/third_party/icu.BUILD
+tensorflow_lite_support/third_party/libyuv.BUILD
+tensorflow_lite_support/third_party/libzip.BUILD
+tensorflow_lite_support/third_party/pybind11.BUILD
+tensorflow_lite_support/third_party/python_runtime/BUILD
+tensorflow_lite_support/third_party/six.BUILD
+tensorflow_lite_support/third_party/stblib.BUILD
+tensorflow_lite_support/third_party/toolchains/java/BUILD
+tensorflow_lite_support/third_party/utf.BUILD
+tensorflow_lite_support/third_party/zlib.BUILD
+tensorflow_lite_support/tools/ci_build/build_all.sh
+tensorflow_lite_support/tools/ci_build/common.sh
+tensorflow_lite_support/tools/ci_build/common_win.bat
+tensorflow_lite_support/tools/pip_package/BUILD
+tensorflow_lite_support/tools/pip_package/MANIFEST.in
+tensorflow_lite_support/tools/pip_package/README
+tensorflow_lite_support/tools/pip_package/build_pip_package.sh
+tensorflow_lite_support/tools/pip_package/setup.py
+tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py
+tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py \ No newline at end of file
diff --git a/tensorflow_lite_support/tools/BUILD b/tensorflow_lite_support/tools/BUILD
new file mode 100644
index 00000000..c3525ca4
--- /dev/null
+++ b/tensorflow_lite_support/tools/BUILD
@@ -0,0 +1,20 @@
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+py_binary(
+ name = "zip_files",
+ srcs = ["zip_files.py"],
+ python_version = "PY3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "@absl_py//absl:app",
+ "@absl_py//absl/flags",
+ ],
+)
+
+py_library(
+ name = "expect_flatbuffers_installed",
+ srcs = [],
+)
diff --git a/tensorflow_lite_support/tools/build_rules/expand_template.bzl b/tensorflow_lite_support/tools/build_rules/expand_template.bzl
new file mode 100644
index 00000000..717860ca
--- /dev/null
+++ b/tensorflow_lite_support/tools/build_rules/expand_template.bzl
@@ -0,0 +1,50 @@
+"""Build macro for libzip."""
+
+# forked from kythe/kythe/tools/build_rules/expand_template.bzl
+def _expand_template_impl(ctx):
+ ctx.actions.expand_template(
+ template = ctx.file.template,
+ output = ctx.outputs.out,
+ substitutions = ctx.attr.substitutions,
+ )
+
+expand_template = rule(
+ attrs = {
+ "out": attr.output(mandatory = True),
+ "substitutions": attr.string_dict(mandatory = True),
+ "template": attr.label(
+ mandatory = True,
+ allow_single_file = True,
+ ),
+ },
+ output_to_genfiles = True,
+ implementation = _expand_template_impl,
+)
+
+def cmake_substitutions(vars, defines = {}):
+ """Returns a dict of template substitutions combining `vars` and `defines`.
+
+ Args:
+ vars: will be turned into a dict replacing `${key}` and `@key@` with `value`.
+      defines: will be turned into a dict replacing `#cmakedefine MACRO` with
+          `#define MACRO` if the value is true, otherwise `/* #undef MACRO */`.
+ Returns:
+ substitutions
+ """
+ subs = {}
+ for key, value in vars.items():
+ subs["${%s}" % (key,)] = str(value) if value != None else ""
+ subs["@%s@" % (key,)] = str(value) if value != None else ""
+
+ # TODO(shahms): Better handling of #cmakedefine delimiters and line endings to
+ # avoid the prefix-substitution problem.
+ # Potentially allow value to be: True, False, None or string.
+ # True/False => Same as current
+ # None => assume no suffix value, include \n in sub and replacement
+ # string => use string to lookup in vars and assume ${} or @@ tail?
+ for macro, present in defines.items():
+ if present:
+ subs["#cmakedefine %s" % macro] = "#define %s" % macro
+ else:
+ subs["#cmakedefine %s" % macro] = "/* #undef %s */" % macro
+ return subs
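+
+# Editor's usage sketch (values are illustrative, not from this repo):
+#
+#   subs = cmake_substitutions(
+#       vars = {"VERSION": "1.2.1"},
+#       defines = {"HAVE_UNISTD_H": True, "HAVE_WINDOWS_H": False},
+#   )
+#   # subs == {
+#   #     "${VERSION}": "1.2.1",
+#   #     "@VERSION@": "1.2.1",
+#   #     "#cmakedefine HAVE_UNISTD_H": "#define HAVE_UNISTD_H",
+#   #     "#cmakedefine HAVE_WINDOWS_H": "/* #undef HAVE_WINDOWS_H */",
+#   # }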
diff --git a/tensorflow_lite_support/tools/ci_build/build_all.sh b/tensorflow_lite_support/tools/ci_build/build_all.sh
new file mode 100644
index 00000000..7e98b6c9
--- /dev/null
+++ b/tensorflow_lite_support/tools/ci_build/build_all.sh
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# External `build_all.sh`
+
+set -ex
+
+bazel build -c opt --config=monolithic \
+ //tensorflow_lite_support/java:tensorflowlite_support \
+ //tensorflow_lite_support/codegen/python:codegen \
+ //tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata_lib \
+ //tensorflow_lite_support/metadata/cc:metadata_extractor \
+ //tensorflow_lite_support/custom_ops/kernel:all \
+ //tensorflow_lite_support/custom_ops/python:tflite_text_api
+
+# Build Task libraries.
+bazel build -c opt --config=monolithic \
+ --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
+ //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base-task-api.aar \
+ //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text:task-library-text \
+ //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision:task-library-vision
+
+
+# Run Metadata tests.
+bazel clean --expunge
+
+bazel test --test_output=all \
+ //tensorflow_lite_support/metadata/python/tests:metadata_test \
+ //tensorflow_lite_support/metadata/python/tests/metadata_writers:all
+
diff --git a/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh b/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh
new file mode 100644
index 00000000..a36f97d3
--- /dev/null
+++ b/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh
@@ -0,0 +1,114 @@
+#!/bin/bash
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Pip install TensorFlow Lite Support and run basic test on the pip package.
+
+# Important: Use msys shell to run this script on Windows.
+
+set -e
+set -x
+
+function run_smoke_test() {
+ VENV_TMP_DIR="$(mktemp -d)"
+
+ if [[ "$OSTYPE" == "msys" ]]; then
+ VENV_TMP_DIR="$(cygpath -m $VENV_TMP_DIR)"
+ fi
+
+ ${PYTHON_BIN_PATH} -m virtualenv -p ${PYTHON_BIN_PATH} "${VENV_TMP_DIR}" || \
+ die "FAILED: Unable to create virtualenv"
+
+ if [[ "$OSTYPE" == "msys" ]]; then
+ source "${VENV_TMP_DIR}/Scripts/activate" || \
+ die "FAILED: Unable to activate virtualenv "
+ else
+ source "${VENV_TMP_DIR}/bin/activate" || \
+ die "FAILED: Unable to activate virtualenv "
+ fi
+
+ # install tflite-support
+ python -m pip install ${WHL_NAME} || \
+ die "pip install (forcing to reinstall tflite-support) FAILED"
+ echo "Successfully installed pip package ${WHL_NAME}"
+
+ # Download a test model
+ export TEST_MODEL="$(pwd)/test.tflite"
+ wget https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_0.75_192_quantized/1/metadata/1\?lite-format\=tflite -O "$TEST_MODEL"
+ if [[ "$OSTYPE" == "msys" ]]; then
+ TEST_MODEL=$(cygpath -m $TEST_MODEL)
+ fi
+
+ test_tfls_imports
+
+ test_codegen
+
+ # Deactivate from virtualenv.
+ deactivate || source deactivate || \
+ die "FAILED: Unable to deactivate from existing virtualenv."
+
+ echo "All smoke test passes!"
+}
+
+function test_tfls_imports() {
+ TMP_DIR=$(mktemp -d)
+ pushd "${TMP_DIR}"
+
+ # test for basic import and metadata display.
+ RET_VAL=$(python -c "from tflite_support import metadata; \
+md = metadata.MetadataDisplayer.with_model_file(\"$TEST_MODEL\"); \
+print(md.get_metadata_json())")
+
+ # just check if the model name is there.
+ if ! [[ ${RET_VAL} == *"MobileNetV1 image classifier (quantized)"* ]]; then
+ echo "Unexpected return value: ${RET_VAL}"
+ echo "PIP smoke test on virtualenv FAILED, do not upload ${WHL_NAME}."
+ return 1
+ fi
+
+ RESULT=$?
+
+ popd
+ return $RESULT
+}
+
+function test_codegen() {
+ TMP_DIR=$(mktemp -d)
+ pushd "${TMP_DIR}"
+
+  # Run codegen on the test model.
+ tflite_codegen --model ${TEST_MODEL} --destination tmp
+ RESULT=$?
+
+  # Just check that codegen exited successfully.
+ if [[ ${RESULT} -ne 0 ]]; then
+ echo "Unexpected return value: ${RESULT}"
+ echo "PIP smoke test on virtualenv FAILED, do not upload ${WHL_NAME}."
+ return 1
+ fi
+
+ popd
+ return $RESULT
+}
+
+###########################################################################
+# Main
+###########################################################################
+if [[ -z "${1}" ]]; then
+ echo "TFLite Support WHL path not given, unable to install and test."
+  exit 1
+fi
+
+WHL_NAME=${1}
+run_smoke_test
diff --git a/tensorflow_lite_support/tools/ci_build/common.sh b/tensorflow_lite_support/tools/ci_build/common.sh
new file mode 100644
index 00000000..4907bb1b
--- /dev/null
+++ b/tensorflow_lite_support/tools/ci_build/common.sh
@@ -0,0 +1,96 @@
+#!/usr/bin/env bash
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# External `common.sh`
+
+# Keep in sync with tensorflow core and configure.py.
+# TODO(b/158448780): Guard bazel version with IfChangeThenChange.
+LATEST_BAZEL_VERSION=3.1.0
+
+# Run flaky functions with retries.
+# run_with_retry cmd
+function run_with_retry {
+ eval "$1"
+ # If the command fails retry again in 60 seconds.
+ if [[ $? -ne 0 ]]; then
+ sleep 60
+ eval "$1"
+ fi
+}
+
+function die() {
+ echo "$@" 1>&2 ; exit 1;
+}
+
+# A small utility to run the command and only print logs if the command fails.
+# On success, all logs are hidden.
+function readable_run {
+ # Disable debug mode to avoid printing of variables here.
+ set +x
+ result=$("$@" 2>&1) || die "$result"
+ echo "$@"
+ echo "Command completed successfully at $(date)"
+ set -x
+}
+
+# TODO(b/158448780): Guard bazel installation with IfChangeThenChange.
+function set_bazel_outdir {
+ mkdir -p /tmpfs/bazel_output
+ export TEST_TMPDIR=/tmpfs/bazel_output
+}
+
+# Downloads bazelisk to ~/bin as `bazel`.
+function install_bazelisk {
+ date
+ case "$(uname -s)" in
+ Darwin) local name=bazelisk-darwin-amd64 ;;
+ Linux) local name=bazelisk-linux-amd64 ;;
+ *) die "Unknown OS: $(uname -s)" ;;
+ esac
+ mkdir -p "$HOME/bin"
+ wget --no-verbose -O "$HOME/bin/bazel" \
+ "https://github.com/bazelbuild/bazelisk/releases/download/v1.3.0/$name"
+ chmod u+x "$HOME/bin/bazel"
+ if [[ ! ":$PATH:" =~ :"$HOME"/bin/?: ]]; then
+ PATH="$HOME/bin:$PATH"
+ fi
+ set_bazel_outdir
+ which bazel
+ bazel version
+ date
+}
+
+# Install the given bazel version on linux
+function update_bazel_linux {
+ if [[ -z "$1" ]]; then
+ BAZEL_VERSION=${LATEST_BAZEL_VERSION}
+ else
+ BAZEL_VERSION=$1
+ fi
+ rm -rf ~/bazel
+ mkdir ~/bazel
+
+ pushd ~/bazel
+ readable_run wget https://github.com/bazelbuild/bazel/releases/download/"${BAZEL_VERSION}"/bazel-"${BAZEL_VERSION}"-installer-linux-x86_64.sh
+ chmod +x bazel-*.sh
+ ./bazel-"${BAZEL_VERSION}"-installer-linux-x86_64.sh --user
+ rm bazel-"${BAZEL_VERSION}"-installer-linux-x86_64.sh
+ popd
+
+ PATH="/home/kbuilder/bin:$PATH"
+ set_bazel_outdir
+ which bazel
+ bazel version
+}
diff --git a/tensorflow_lite_support/tools/ci_build/common_win.bat b/tensorflow_lite_support/tools/ci_build/common_win.bat
new file mode 100644
index 00000000..35f39a72
--- /dev/null
+++ b/tensorflow_lite_support/tools/ci_build/common_win.bat
@@ -0,0 +1,29 @@
+:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+::
+:: Licensed under the Apache License, Version 2.0 (the "License");
+:: you may not use this file except in compliance with the License.
+:: You may obtain a copy of the License at
+::
+:: http://www.apache.org/licenses/LICENSE-2.0
+::
+:: Unless required by applicable law or agreed to in writing, software
+:: distributed under the License is distributed on an "AS IS" BASIS,
+:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+:: See the License for the specific language governing permissions and
+:: limitations under the License.
+:: =============================================================================
+
+:: This script is shamefully borrowed from:
+:: //third_party/tensorflow/tools/ci_build/release/common_win.bat.oss
+
+echo on
+
+@REM
+@REM Setup Bazel
+@REM
+:: Download Bazel from github and make sure its found in PATH.
+SET BAZEL_VERSION=3.1.0
+md C:\tools\bazel\
+wget -q https://github.com/bazelbuild/bazel/releases/download/%BAZEL_VERSION%/bazel-%BAZEL_VERSION%-windows-x86_64.exe -O C:/tools/bazel/bazel.exe
+SET PATH=C:\tools\bazel;%PATH%
+bazel version
diff --git a/tensorflow_lite_support/tools/ci_build/update_version.py b/tensorflow_lite_support/tools/ci_build/update_version.py
new file mode 100644
index 00000000..86fa588d
--- /dev/null
+++ b/tensorflow_lite_support/tools/ci_build/update_version.py
@@ -0,0 +1,120 @@
+# lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Update version code in the repo.
+
+We use a python script rather than GNU tools to avoid cross-platform
+difficulties.
+
+The script takes 3 arguments:
+  --src <path> a path pointing to the code repo.
+  --version <version> the new version code.
+  --nightly [default: false] when set, a build suffix (e.g. dev20201103) will
+    be appended to the version code.
+
+It should not be run by bazel. Use it as a simple python script.
+"""
+
+import argparse
+import datetime
+import os
+import re
+
+SETUP_PY_PATH = "tensorflow_lite_support/tools/pip_package/setup.py"
+
+
+def replace_string_in_line(search, replace, filename):
+ """Replace the string in every line of the file in-place."""
+ with open(filename, "r") as f:
+ content = f.read()
+ with open(filename, "w") as f:
+ f.write(re.sub(search, replace, content))
+
+
+def get_current_version(path):
+  """Get the current version code from setup.py."""
+  with open(os.path.join(path, SETUP_PY_PATH)) as f:
+    for line in f:
+      match = re.search("^_VERSION = '([a-z0-9\\.\\-]+)'", line)
+      if match:
+        return match.group(1)
+  print("Cannot find current version!")
+  return None
+
+
+def update_version(path, current_version, new_version):
+ """Update the version code in the codebase."""
+ # Update setup.py
+ replace_string_in_line(
+ "_VERSION = '%s'" % current_version,
+ # pep440 requires such a replacement
+ "_VERSION = '%s'" % new_version.replace("-", "."),
+ os.path.join(path, SETUP_PY_PATH))
+
+
+class CustomTimeZone(datetime.tzinfo):
+
+ def utcoffset(self, dt):
+ return -datetime.timedelta(hours=8)
+
+ def tzname(self, dt):
+ return "UTC-8"
+
+ def dst(self, dt):
+ return datetime.timedelta(0)
+
+
+def remove_build_suffix(version):
+ """Remove build suffix (if exists) from a version."""
+ if version.find("-dev") >= 0:
+ return version[:version.find("-dev")]
+ if version.find(".dev") >= 0:
+ return version[:version.find(".dev")]
+ if version.find("dev") >= 0:
+ return version[:version.find("dev")]
+ return version
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Update TFLS version in repo")
+ parser.add_argument(
+ "--src",
+ help="a path pointing to the code repo",
+ required=True,
+ default="")
+ parser.add_argument("--version", help="the new SemVer code", default="")
+ parser.add_argument(
+ "--nightly",
+ help="if true, a build suffix will append to the version code. If "
+ "current version code or the <version> argument provided contains a "
+ "build suffix, the suffix will be replaced with the timestamp",
+ action="store_true")
+ args = parser.parse_args()
+
+ path = args.src
+ current_version = get_current_version(path)
+ if not current_version:
+ return
+ new_version = args.version if args.version else current_version
+ if args.nightly:
+ new_version = remove_build_suffix(new_version)
+ # Use UTC-8 rather than uncertain local time.
+ d = datetime.datetime.now(tz=CustomTimeZone())
+ new_version += "-dev" + d.strftime("%Y%m%d")
+ print("Updating version from %s to %s" % (current_version, new_version))
+ update_version(path, current_version, new_version)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tensorflow_lite_support/tools/pip_package/BUILD b/tensorflow_lite_support/tools/pip_package/BUILD
new file mode 100644
index 00000000..61df24a4
--- /dev/null
+++ b/tensorflow_lite_support/tools/pip_package/BUILD
@@ -0,0 +1,55 @@
+# Description:
+# Tools for building the TensorFlow pip package.
+
+load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
+
+package(default_visibility = ["//visibility:private"])
+
+COMMON_PIP_DEPS = [
+ ":licenses",
+ "MANIFEST.in",
+ "README",
+ "setup.py",
+ "//tensorflow_lite_support/codegen/python:codegen",
+ "//tensorflow_lite_support/metadata/python:metadata",
+]
+
+filegroup(
+ name = "licenses",
+ data = [
+ "//:LICENSE",
+ "@org_tensorflow//:LICENSE",
+ ] + if_not_system_lib(
+ "absl_py",
+ [
+ "@absl_py//absl:LICENSE",
+ "@absl_py//absl/logging:LICENSE",
+ "@absl_py//absl/flags:LICENSE",
+ "@absl_py//absl/testing:LICENSE",
+ "@absl_py//absl/third_party/unittest3_backport:LICENSE",
+ ],
+ ),
+)
+
+sh_binary(
+ name = "build_pip_package",
+ srcs = ["build_pip_package.sh"],
+ data = COMMON_PIP_DEPS +
+ select({
+ "@org_tensorflow//tensorflow:windows": [
+ ":simple_console_for_windows",
+ ],
+ "//conditions:default": [
+ ],
+ }),
+)
+
+# On Windows, the python binary is a zip file of the runfiles tree.
+# Add everything to its data dependencies to generate a runfiles tree
+# for building the pip package on Windows.
+py_binary(
+ name = "simple_console_for_windows",
+ srcs = ["simple_console_for_windows.py"],
+ data = COMMON_PIP_DEPS,
+ srcs_version = "PY2AND3",
+)
diff --git a/tensorflow_lite_support/tools/pip_package/MANIFEST.in b/tensorflow_lite_support/tools/pip_package/MANIFEST.in
new file mode 100644
index 00000000..e44f271f
--- /dev/null
+++ b/tensorflow_lite_support/tools/pip_package/MANIFEST.in
@@ -0,0 +1,9 @@
+include LICENSE
+include README.md
+recursive-include * *.py
+recursive-include * *.pyd
+recursive-include * *.fbs
+recursive-include * *.so
+recursive-include * *.dylib
+recursive-include * *.dll
+recursive-include * *.lib
diff --git a/tensorflow_lite_support/tools/pip_package/README b/tensorflow_lite_support/tools/pip_package/README
new file mode 100644
index 00000000..1e1f9d5a
--- /dev/null
+++ b/tensorflow_lite_support/tools/pip_package/README
@@ -0,0 +1 @@
+TensorFlow Lite Support
diff --git a/tensorflow_lite_support/tools/pip_package/build_pip_package.sh b/tensorflow_lite_support/tools/pip_package/build_pip_package.sh
new file mode 100755
index 00000000..2f962e3f
--- /dev/null
+++ b/tensorflow_lite_support/tools/pip_package/build_pip_package.sh
@@ -0,0 +1,232 @@
+#!/usr/bin/env bash
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+function is_absolute {
+ [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]]
+}
+
+function real_path() {
+ is_absolute "$1" && echo "$1" || echo "$PWD/${1#./}"
+}
+
+function move_to_root_if_exists () {
+ arg_to_move="$1"
+ if [ -e "${arg_to_move}" ]; then
+    mv "${arg_to_move}" ./
+ fi
+}
+
+# Kept for parity with TensorFlow's build_pip_package.sh; this repo currently
+# has no external includes to reorganize.
+function reorganize_includes() {
+  TMPDIR="${1%/}"
+}
+
+PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
+function is_windows() {
+ if [[ "${PLATFORM}" =~ (cygwin|mingw32|mingw64|msys)_nt* ]]; then
+ true
+ else
+ false
+ fi
+}
+
+function prepare_src() {
+ if [ $# -lt 1 ] ; then
+ echo "No destination dir provided"
+ exit 1
+ fi
+
+ TMPDIR="${1%/}"
+ mkdir -p "$TMPDIR"
+ EXTERNAL_INCLUDES="${TMPDIR}/tflite_support/include/external"
+
+ echo $(date) : "=== Preparing sources in dir: ${TMPDIR}"
+
+ if [ ! -d bazel-bin/tensorflow_lite_support ]; then
+ echo "Could not find bazel-bin. Did you run from the root of the build tree?"
+ exit 1
+ fi
+
+ if is_windows; then
+ rm -rf ./bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows_unzip
+ mkdir -p ./bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows_unzip
+ echo "Unzipping simple_console_for_windows.zip to create runfiles tree..."
+ unzip -o -q ./bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows.zip -d ./bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows_unzip
+ echo "Unzip finished."
+ # runfiles structure after unzip the python binary
+ RUNFILES=bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows_unzip/runfiles/org_tensorflow_lite_support
+
+ # TODO(b/165872313): Investigate the case and remove the hack.
+    # On Windows, __init__.py files are not auto-generated in directories
+    # that only contain pybind libraries.
+ touch "$RUNFILES/tensorflow_lite_support/metadata/cc/__init__.py"
+ touch "$RUNFILES/tensorflow_lite_support/metadata/cc/python/__init__.py"
+ touch "$RUNFILES/tensorflow_lite_support/metadata/flatbuffers_lib/__init__.py"
+ else
+ RUNFILES=bazel-bin/tensorflow_lite_support/tools/pip_package/build_pip_package.runfiles/org_tensorflow_lite_support
+ fi
+
+ cp "$RUNFILES/LICENSE" "${TMPDIR}"
+ cp -R "$RUNFILES/tensorflow_lite_support" "${TMPDIR}"
+
+ reorganize_includes "${TMPDIR}"
+
+ cp tensorflow_lite_support/tools/pip_package/MANIFEST.in ${TMPDIR}
+ cp tensorflow_lite_support/tools/pip_package/README ${TMPDIR}/README.md
+ cp tensorflow_lite_support/tools/pip_package/setup.py ${TMPDIR}
+
+ # A helper entry.
+ mkdir ${TMPDIR}/tflite_support
+ cp tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py ${TMPDIR}/tflite_support/__init__.py
+}
+
+function build_wheel() {
+ if [ $# -lt 2 ] ; then
+ echo "No src and dest dir provided"
+ exit 1
+ fi
+
+ TMPDIR="$1"
+ DEST="$2"
+ PKG_NAME_FLAG="$3"
+
+ # Before we leave the top-level directory, make sure we know how to
+ # call python.
+ if [[ -e tools/python_bin_path.sh ]]; then
+ source tools/python_bin_path.sh
+ fi
+
+ pushd ${TMPDIR} > /dev/null
+
+ rm -f MANIFEST
+ echo $(date) : "=== Building wheel"
+ "${PYTHON_BIN_PATH:-python}" setup.py bdist_wheel ${PKG_NAME_FLAG} >/dev/null
+ mkdir -p ${DEST}
+ cp dist/* ${DEST}
+ popd > /dev/null
+ echo $(date) : "=== Output wheel file is in: ${DEST}"
+}
+
+function usage() {
+ echo "Usage:"
+ echo "$0 [--src srcdir] [--dst dstdir] [options]"
+ echo "$0 dstdir [options]"
+ echo ""
+ echo " --src prepare sources in srcdir"
+ echo " will use temporary dir if not specified"
+ echo ""
+ echo " --dst build wheel in dstdir"
+ echo " if dstdir is not set do not build, only prepare sources"
+ echo ""
+ echo " Options:"
+ echo " --project_name <name> set project name to <name>"
+ echo " --version <version> reset the pip package version to <version>"
+ echo " --nightly_flag build TFLite Support nightly"
+ echo ""
+ echo "When using bazel, add the following flag: --run_under=\"cd \$PWD && \""
+ echo ""
+ exit 1
+}
+
+function main() {
+ PKG_NAME_FLAG=""
+ PROJECT_NAME=""
+ NIGHTLY_BUILD=0
+ SRCDIR=""
+ DSTDIR=""
+ CLEANSRC=1
+ VERSION=""
+ while true; do
+ if [[ "$1" == "--help" ]]; then
+ usage
+ exit 1
+ elif [[ "$1" == "--nightly_flag" ]]; then
+ NIGHTLY_BUILD=1
+ elif [[ "$1" == "--project_name" ]]; then
+ shift
+ if [[ -z "$1" ]]; then
+ break
+ fi
+ PROJECT_NAME="$1"
+ elif [[ "$1" == "--version" ]]; then
+ shift
+ if [[ -z "$1" ]]; then
+ break
+ fi
+ VERSION="$1"
+ elif [[ "$1" == "--src" ]]; then
+ shift
+ SRCDIR="$(real_path $1)"
+ CLEANSRC=0
+ elif [[ "$1" == "--dst" ]]; then
+ shift
+ DSTDIR="$(real_path $1)"
+ else
+ echo "Unrecognized flag: $1"
+ usage
+ exit 1
+ fi
+ shift
+
+ if [[ -z "$1" ]]; then
+ break
+ fi
+ done
+
+ if [[ -z "$DSTDIR" ]] && [[ -z "$SRCDIR" ]]; then
+ echo "No destination dir provided"
+ usage
+ exit 1
+ fi
+
+ if [[ -z "$SRCDIR" ]]; then
+ # make temp srcdir if none set
+ SRCDIR="$(mktemp -d -t tmp.XXXXXXXXXX)"
+ fi
+
+ if [[ -z "$DSTDIR" ]]; then
+ # only want to prepare sources
+ exit
+ fi
+
+ if [[ -n ${PROJECT_NAME} ]]; then
+ PKG_NAME_FLAG="--project_name ${PROJECT_NAME}"
+ elif [[ ${NIGHTLY_BUILD} == "1" ]]; then
+ PKG_NAME_FLAG="--project_name tflite_support_nightly"
+ fi
+
+ if [[ ${NIGHTLY_BUILD} == "1" ]]; then
+ # we use a script to update versions to avoid any tool differences on different platforms.
+ if [[ ! -z ${VERSION} ]]; then
+ python tensorflow_lite_support/tools/ci_build/update_version.py --src "." --version ${VERSION} --nightly
+ else
+ python tensorflow_lite_support/tools/ci_build/update_version.py --src "." --nightly
+ fi
+ elif [[ ! -z ${VERSION} ]]; then
+ python tensorflow_lite_support/tools/ci_build/update_version.py --src "." --version ${VERSION}
+ fi
+
+ prepare_src "$SRCDIR"
+
+ build_wheel "$SRCDIR" "$DSTDIR" "$PKG_NAME_FLAG"
+
+ if [[ $CLEANSRC -ne 0 ]]; then
+ rm -rf "${TMPDIR}"
+ fi
+}
+
+main "$@"
diff --git a/tensorflow_lite_support/tools/pip_package/setup.py b/tensorflow_lite_support/tools/pip_package/setup.py
new file mode 100644
index 00000000..460c5057
--- /dev/null
+++ b/tensorflow_lite_support/tools/pip_package/setup.py
@@ -0,0 +1,154 @@
+# lint as: python3
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TFLite Support is a toolkit that helps users to develop ML and deploy TFLite models onto mobile devices.
+
+This PyPI package includes the Python bindings for following features:
+
+ - Metadata schemas: wraps TFLite model schema and metadata schema in Python.
+  - Metadata populator and displayer: can be used to populate the metadata and
+    associated files into the model, as well as to convert the populated
+    metadata into the JSON format.
+ - Android Codegen tool: generates the Java model interface used in Android for
+ a particular model.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import fnmatch
+import os
+import re
+import sys
+
+from setuptools import Command
+from setuptools import find_packages
+from setuptools import setup
+from setuptools.command.install import install as InstallCommandBase
+from setuptools.dist import Distribution
+
+# This version string is semver compatible, but incompatible with pip.
+# For pip, we will remove all '-' characters from this string and use the
+# result.
+_VERSION = '0.1.0'
+
+SETUP_PACKAGES = [
+ 'pybind11 >= 2.6.0',
+]
+
+REQUIRED_PACKAGES = [
+ 'absl-py >= 0.7.0',
+ 'numpy >= 1.16.0',
+ 'flatbuffers >= 1.12',
+] + SETUP_PACKAGES
+
+project_name = 'tflite-support'
+if '--project_name' in sys.argv:
+ project_name_idx = sys.argv.index('--project_name')
+ project_name = sys.argv[project_name_idx + 1]
+ sys.argv.remove('--project_name')
+ sys.argv.pop(project_name_idx)
+
+DOCLINES = __doc__.split('\n')
+
+CONSOLE_SCRIPTS = [
+ 'tflite_codegen = tensorflow_lite_support.codegen.python.codegen:main',
+]
+
+
+class BinaryDistribution(Distribution):
+
+ def has_ext_modules(self):
+ return True
+
+
+class InstallCommand(InstallCommandBase):
+ """Override the dir where the headers go."""
+
+ def finalize_options(self):
+ ret = InstallCommandBase.finalize_options(self)
+ self.install_lib = self.install_platlib
+ return ret
+
+
+def find_files(pattern, root):
+ """Return all the files matching pattern below root dir."""
+ for dirpath, _, files in os.walk(root):
+ for filename in fnmatch.filter(files, pattern):
+ yield os.path.join(dirpath, filename)
+
+
+so_lib_paths = [
+ i for i in os.listdir('.')
+ if os.path.isdir(i) and fnmatch.fnmatch(i, '_solib_*')
+]
+
+matches = []
+for path in so_lib_paths:
+ matches.extend(['../' + x for x in find_files('*', path) if '.py' not in x])
+
+EXTENSIONS = ['codegen/_pywrap_codegen.so']
+
+headers = ()
+
+setup(
+ name=project_name,
+ version=_VERSION.replace('-', ''),
+ description=DOCLINES[0],
+ long_description='\n'.join(DOCLINES),
+ long_description_content_type='text/markdown',
+ url='https://www.tensorflow.org/',
+ download_url='https://github.com/tensorflow/tflite-support/tags',
+ author='Google, LLC.',
+ author_email='packages@tensorflow.org',
+ # Contained modules and scripts.
+ packages=find_packages(),
+ entry_points={
+ 'console_scripts': CONSOLE_SCRIPTS,
+ },
+ headers=headers,
+ setup_requires=SETUP_PACKAGES,
+ install_requires=REQUIRED_PACKAGES,
+ tests_require=REQUIRED_PACKAGES,
+ # Add in any packaged data.
+ include_package_data=True,
+ package_data={
+ 'tflite-support': EXTENSIONS + matches,
+ },
+ zip_safe=False,
+ distclass=BinaryDistribution,
+ cmdclass={
+ 'install': InstallCommand,
+ },
+ # PyPI package information.
+ classifiers=sorted([
+ 'Development Status :: 3 - Alpha',
+ 'Intended Audience :: Developers',
+ 'Intended Audience :: Education',
+ 'Intended Audience :: Science/Research',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Programming Language :: Python :: 2.7',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Topic :: Scientific/Engineering',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Topic :: Software Development',
+ 'Topic :: Software Development :: Libraries',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
+ ]),
+ license='Apache 2.0',
+)
diff --git a/tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py b/tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py
new file mode 100644
index 00000000..106528bb
--- /dev/null
+++ b/tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py
@@ -0,0 +1,33 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Start a simple interactive console with TensorFlow available."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import code
+import sys
+
+
+def main(_):
+ """Run an interactive console."""
+ code.interact()
+ return 0
+
+
+if __name__ == '__main__':
+ sys.exit(main(sys.argv))
diff --git a/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py b/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py
new file mode 100644
index 00000000..c5bf3de5
--- /dev/null
+++ b/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An import entry for the TFLite Support project.
+
+In the original project structure, all python targets are accessed by paths
+like tensorflow_lite_support.metadata.metadata.MetadataDisplayer, which is
+verbose and deep. This file provides some shortcuts. It's also compatible with
+the first version of our pip package.
+
+In the pip package build, this file will be renamed to
+tflite_support/__init__.py.
+"""
+
+import flatbuffers
+from tensorflow_lite_support.metadata import metadata_schema_py_generated
+from tensorflow_lite_support.metadata import schema_py_generated
+from tensorflow_lite_support.metadata.python import metadata
diff --git a/tensorflow_lite_support/tools/zip_files.py b/tensorflow_lite_support/tools/zip_files.py
new file mode 100644
index 00000000..9dc66236
--- /dev/null
+++ b/tensorflow_lite_support/tools/zip_files.py
@@ -0,0 +1,41 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Lint as: python3
+"""Creates a zip package of the files passed in."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import zipfile
+
+from absl import app
+from absl import flags
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string("export_zip_path", None, "Path to zip file.")
+flags.DEFINE_string("file_directory", None, "Path to the files to be zipped.")
+
+
+def main(_):
+ with zipfile.ZipFile(FLAGS.export_zip_path, mode="w") as zf:
+ for root, _, files in os.walk(FLAGS.file_directory):
+ for f in files:
+ if f.endswith(".java"):
+ zf.write(os.path.join(root, f))
+
+
+if __name__ == "__main__":
+ app.run(main)
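+
+# Editor's usage sketch (hypothetical paths): the tool zips only the .java
+# files found under --file_directory, e.g.
+#
+#   python zip_files.py --export_zip_path=/tmp/codegen_srcs.zip \
+#       --file_directory=bazel-bin/some/codegen/output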
diff --git a/third_party/BUILD b/third_party/BUILD
new file mode 100644
index 00000000..fe756e1b
--- /dev/null
+++ b/third_party/BUILD
@@ -0,0 +1 @@
+licenses(["notice"]) # Apache 2.0
diff --git a/third_party/android/BUILD b/third_party/android/BUILD
new file mode 100644
index 00000000..fd69d4ba
--- /dev/null
+++ b/third_party/android/BUILD
@@ -0,0 +1 @@
+# Placeholder to make bazel treat it as a package.
diff --git a/third_party/android/android.bzl.tpl b/third_party/android/android.bzl.tpl
new file mode 100644
index 00000000..e6ed4994
--- /dev/null
+++ b/third_party/android/android.bzl.tpl
@@ -0,0 +1,9 @@
+"""Set up configurable Android SDK and NDK dependencies."""
+
+def android_workspace():
+ # String for replacement in Bazel template.
+ # These will either be replaced by android_sdk_repository if various ENV
+ # variables are set when `local_config_android` repo_rule is run, or they
+ # will be replaced by noops otherwise.
+ MAYBE_ANDROID_SDK_REPOSITORY
+ MAYBE_ANDROID_NDK_REPOSITORY
diff --git a/third_party/android/android_configure.BUILD.tpl b/third_party/android/android_configure.BUILD.tpl
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/third_party/android/android_configure.BUILD.tpl
diff --git a/third_party/android/android_configure.bzl b/third_party/android/android_configure.bzl
new file mode 100644
index 00000000..2fd2d807
--- /dev/null
+++ b/third_party/android/android_configure.bzl
@@ -0,0 +1,95 @@
+"""Repository rule for Android SDK and NDK autoconfiguration.
+
+`android_configure` depends on the following environment variables:
+
+ * `ANDROID_NDK_HOME`: Location of Android NDK root.
+ * `ANDROID_SDK_HOME`: Location of Android SDK root.
+ * `ANDROID_SDK_API_LEVEL`: Desired Android SDK API version.
+ * `ANDROID_NDK_API_LEVEL`: Desired Android NDK API version.
+ * `ANDROID_BUILD_TOOLS_VERSION`: Desired Android build tools version.
+
+
+Writes Android SDK and NDK rules.
+
+Add the following to your WORKSPACE FILE:
+
+```python
+android_configure(name = "local_config_android")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""
+
+_ANDROID_NDK_HOME = "ANDROID_NDK_HOME"
+_ANDROID_SDK_HOME = "ANDROID_SDK_HOME"
+_ANDROID_NDK_API_VERSION = "ANDROID_NDK_API_LEVEL"
+_ANDROID_SDK_API_VERSION = "ANDROID_SDK_API_LEVEL"
+_ANDROID_BUILD_TOOLS_VERSION = "ANDROID_BUILD_TOOLS_VERSION"
+
+_ANDROID_SDK_REPO_TEMPLATE = """
+ native.android_sdk_repository(
+ name="androidsdk",
+ path="%s",
+ api_level=%s,
+ build_tools_version="%s",
+ )
+"""
+
+_ANDROID_NDK_REPO_TEMPLATE = """
+ native.android_ndk_repository(
+ name="androidndk",
+ path="%s",
+ api_level=%s,
+ )
+"""
+
+def _android_autoconf_impl(repository_ctx):
+ """Implementation of the android_autoconf repository rule."""
+ sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME)
+ sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION)
+ build_tools_version = repository_ctx.os.environ.get(
+ _ANDROID_BUILD_TOOLS_VERSION,
+ )
+ ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME)
+ ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION)
+
+ sdk_rule = ""
+ if all([sdk_home, sdk_api_level, build_tools_version]):
+ sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % (
+ sdk_home,
+ sdk_api_level,
+ build_tools_version,
+ )
+
+ ndk_rule = ""
+ if all([ndk_home, ndk_api_level]):
+ ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level)
+
+ if ndk_rule == "" and sdk_rule == "":
+ sdk_rule = "pass"
+ # TODO(xunkai): Add interactive configure script.
+
+ repository_ctx.template(
+ "BUILD",
+ Label("//third_party/android:android_configure.BUILD.tpl"),
+ )
+ repository_ctx.template(
+ "android.bzl",
+ Label("//third_party/android:android.bzl.tpl"),
+ substitutions = {
+ "MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule,
+ "MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule,
+ },
+ )
+
+android_configure = repository_rule(
+ implementation = _android_autoconf_impl,
+ environ = [
+ _ANDROID_SDK_API_VERSION,
+ _ANDROID_NDK_API_VERSION,
+ _ANDROID_BUILD_TOOLS_VERSION,
+ _ANDROID_NDK_HOME,
+ _ANDROID_SDK_HOME,
+ ],
+)
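Taken together with the template above, a hedged WORKSPACE sketch of the full flow (the repository name follows the docstring; `android_workspace` is defined by the generated android.bzl):

```python
# Sketch, assuming this repository's layout.
load("//third_party/android:android_configure.bzl", "android_configure")
android_configure(name = "local_config_android")

# android.bzl is the substituted template; android_workspace() expands to the
# SDK/NDK repository rules, or to no-ops if the ENV variables were unset.
load("@local_config_android//:android.bzl", "android_workspace")
android_workspace()
```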
diff --git a/third_party/com_google_absl.BUILD b/third_party/com_google_absl.BUILD
new file mode 100644
index 00000000..8fca145f
--- /dev/null
+++ b/third_party/com_google_absl.BUILD
@@ -0,0 +1,5 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache
+
+exports_files(["LICENSE"])
diff --git a/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff b/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff
new file mode 100644
index 00000000..0cd2dffa
--- /dev/null
+++ b/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff
@@ -0,0 +1,14 @@
+diff --git a/absl/time/internal/cctz/BUILD.bazel b/absl/time/internal/cctz/BUILD.bazel
+index 9fceffe..e7f9d01 100644
+--- a/absl/time/internal/cctz/BUILD.bazel
++++ b/absl/time/internal/cctz/BUILD.bazel
+@@ -69,8 +69,5 @@ cc_library(
+ "include/cctz/zone_info_source.h",
+ ],
+ linkopts = select({
+- ":osx": [
+- "-framework Foundation",
+- ],
+ ":ios": [
+ "-framework Foundation",
+ ], \ No newline at end of file
diff --git a/third_party/com_google_protobuf_fixes.diff b/third_party/com_google_protobuf_fixes.diff
new file mode 100644
index 00000000..b9bc17ea
--- /dev/null
+++ b/third_party/com_google_protobuf_fixes.diff
@@ -0,0 +1,140 @@
+diff --git a/BUILD b/BUILD
+index 79871d621..51b3a063f 100644
+--- a/BUILD
++++ b/BUILD
+@@ -26,7 +26,7 @@ config_setting(
+ # ZLIB configuration
+ ################################################################################
+
+-ZLIB_DEPS = ["@zlib//:zlib"]
++ZLIB_DEPS = ["@zlib"]
+
+ ################################################################################
+ # Protobuf Runtime Library
+@@ -157,6 +157,7 @@ cc_library(
+ includes = ["src/"],
+ linkopts = LINK_OPTS,
+ visibility = ["//visibility:public"],
++ alwayslink = 1,
+ )
+
+ PROTOBUF_DEPS = select({
+@@ -230,6 +231,7 @@ cc_library(
+ linkopts = LINK_OPTS,
+ visibility = ["//visibility:public"],
+ deps = [":protobuf_lite"] + PROTOBUF_DEPS,
++ alwayslink = 1,
+ )
+
+ # This provides just the header files for use in projects that need to build
+@@ -318,13 +320,13 @@ cc_proto_library(
+
+ [native_cc_proto_library(
+ name = proto + "_cc_proto",
+- deps = [proto + "_proto"],
+ visibility = ["//visibility:private"],
++ deps = [proto + "_proto"],
+ ) for proto in WELL_KNOWN_PROTO_MAP.keys()]
+
+ cc_proto_blacklist_test(
+ name = "cc_proto_blacklist_test",
+- deps = [proto + "_cc_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()]
++ deps = [proto + "_cc_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()],
+ )
+
+ ################################################################################
+@@ -900,7 +902,6 @@ py_proto_library(
+ py_extra_srcs = glob(["python/**/__init__.py"]),
+ py_libs = [
+ ":python_srcs",
+- "@six//:six",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+@@ -1002,7 +1003,9 @@ cc_library(
+ # Note: We use `native_proto_common` here because we depend on an implementation-detail of
+ # `proto_lang_toolchain`, which may not be available on `proto_common`.
+ reject_blacklisted_files = hasattr(native_proto_common, "proto_lang_toolchain_rejects_files_do_not_use_or_we_will_break_you_without_mercy")
++
+ cc_toolchain_blacklisted_protos = [proto + "_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()] if reject_blacklisted_files else [":well_known_protos"]
++
+ proto_lang_toolchain(
+ name = "cc_toolchain",
+ blacklisted_protos = cc_toolchain_blacklisted_protos,
+diff --git a/protobuf.bzl b/protobuf.bzl
+index 829464d44..4ac23594b 100644
+--- a/protobuf.bzl
++++ b/protobuf.bzl
+@@ -87,6 +87,8 @@ def _proto_gen_impl(ctx):
+ for dep in ctx.attr.deps:
+ import_flags += dep.proto.import_flags
+ deps += dep.proto.deps
++ import_flags = depset(import_flags).to_list()
++ deps = depset(deps).to_list()
+
+ if not ctx.attr.gen_cc and not ctx.attr.gen_py and not ctx.executable.plugin:
+ return struct(
+diff --git a/src/google/protobuf/io/gzip_stream.h b/src/google/protobuf/io/gzip_stream.h
+index b1ce1d36c..d5d560ea7 100644
+--- a/src/google/protobuf/io/gzip_stream.h
++++ b/src/google/protobuf/io/gzip_stream.h
+@@ -47,10 +47,12 @@
+ #include <google/protobuf/stubs/common.h>
+ #include <google/protobuf/io/zero_copy_stream.h>
+ #include <google/protobuf/port.h>
+-#include <zlib.h>
+-
+ #include <google/protobuf/port_def.inc>
+
++#if HAVE_ZLIB
++#include <zlib.h>
++#endif // HAVE_ZLIB
++
+ namespace google {
+ namespace protobuf {
+ namespace io {
+@@ -76,8 +78,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream {
+ virtual ~GzipInputStream();
+
+ // Return last error message or NULL if no error.
++#if HAVE_ZLIB
+ inline const char* ZlibErrorMessage() const { return zcontext_.msg; }
+ inline int ZlibErrorCode() const { return zerror_; }
++#endif // HAVE_ZLIB
+
+ // implements ZeroCopyInputStream ----------------------------------
+ bool Next(const void** data, int* size);
+@@ -90,8 +94,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream {
+
+ ZeroCopyInputStream* sub_stream_;
+
++ #if HAVE_ZLIB
+ z_stream zcontext_;
+ int zerror_;
++ #endif // HAVE_ZLIB
+
+ void* output_buffer_;
+ void* output_position_;
+@@ -142,9 +148,11 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream {
+
+ virtual ~GzipOutputStream();
+
++#if HAVE_ZLIB
+ // Return last error message or NULL if no error.
+ inline const char* ZlibErrorMessage() const { return zcontext_.msg; }
+ inline int ZlibErrorCode() const { return zerror_; }
++#endif // HAVE_ZLIB
+
+ // Flushes data written so far to zipped data in the underlying stream.
+ // It is the caller's responsibility to flush the underlying stream if
+@@ -177,8 +185,10 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream {
+ void* sub_data_;
+ int sub_data_size_;
+
++#if HAVE_ZLIB
+ z_stream zcontext_;
+ int zerror_;
++#endif //HAVE_ZLIB
+ void* input_buffer_;
+ size_t input_buffer_length_;
+
diff --git a/third_party/darts_clone.BUILD b/third_party/darts_clone.BUILD
new file mode 100644
index 00000000..1d95ec2f
--- /dev/null
+++ b/third_party/darts_clone.BUILD
@@ -0,0 +1,15 @@
+# Description:
+# Darts-clone is a clone of Darts (Double-ARray Trie System).
+
+licenses(["notice"])
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "darts_clone",
+ hdrs = [
+ "include/darts.h",
+ ],
+)
diff --git a/third_party/fft2d/BUILD b/third_party/fft2d/BUILD
new file mode 100644
index 00000000..863a1cef
--- /dev/null
+++ b/third_party/fft2d/BUILD
@@ -0,0 +1,48 @@
+# Headers for 2D Fast Fourier Transform package
+# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html
+# This is a separate package because the original downloaded archive doesn't
+# contain any header files.
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+# Unrestricted use; can only distribute original package.
+# See fft/readme.txt
+licenses(["notice"])
+
+exports_files(["LICENSE"])
+
+cc_library(
+ name = "fft2d_headers",
+ srcs = [
+ "fft.h",
+ "fft2d.h",
+ ],
+)
+
+objc_library(
+ name = "fft2d_headersd_ios",
+ srcs = [
+ "fft.h",
+ "fft2d.h",
+ ],
+)
+
+# Export the source code so that it can be compiled for Android native apps.
+filegroup(
+ name = "fft2d_headers_srcs",
+ srcs = [
+ "fft.h",
+ "fft2d.h",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = ["**/OWNERS"],
+ ),
+ visibility = ["//third_party/tensorflow:__subpackages__"],
+)
diff --git a/third_party/fft2d/LICENSE b/third_party/fft2d/LICENSE
new file mode 100644
index 00000000..2bd85506
--- /dev/null
+++ b/third_party/fft2d/LICENSE
@@ -0,0 +1,3 @@
+Copyright(C) 1997,2001 Takuya OOURA (email: ooura@kurims.kyoto-u.ac.jp).
+You may use, copy, modify this code for any purpose and
+without fee. You may distribute this ORIGINAL package.
diff --git a/third_party/fft2d/fft.h b/third_party/fft2d/fft.h
new file mode 100644
index 00000000..36d838b7
--- /dev/null
+++ b/third_party/fft2d/fft.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Declarations for 1D FFT routines in third_party/fft2d/fft2d.
+
+#ifndef FFT2D_FFT_H__
+#define FFT2D_FFT_H__
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern void cdft(int, int, double *, int *, double *);
+extern void rdft(int, int, double *, int *, double *);
+extern void ddct(int, int, double *, int *, double *);
+extern void ddst(int, int, double *, int *, double *);
+extern void dfct(int, double *, double *, int *, double *);
+extern void dfst(int, double *, double *, int *, double *);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // FFT2D_FFT_H__
diff --git a/third_party/fft2d/fft2d.BUILD b/third_party/fft2d/fft2d.BUILD
new file mode 100644
index 00000000..9fa5097f
--- /dev/null
+++ b/third_party/fft2d/fft2d.BUILD
@@ -0,0 +1,45 @@
+# 2D Fast Fourier Transform package
+# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+# Unrestricted use; can only distribute original package.
+licenses(["notice"])
+
+exports_files(["readme2d.txt"])
+
+FFT2D_SRCS = [
+ "fftsg.c",
+ "fftsg2d.c",
+]
+
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
+
+# This is the main 2D FFT library. The 2D FFTs in this library call
+# 1D FFTs. In addition, fast DCTs are provided for the special case
+# of 8x8 and 16x16. The code in this library is referred to as
+# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html.
+cc_library(
+ name = "fft2d",
+ srcs = FFT2D_SRCS,
+ linkopts = select({
+ ":windows": [],
+ "//conditions:default": ["-lm"],
+ }),
+)
+
+objc_library(
+ name = "fft2d_ios",
+ srcs = FFT2D_SRCS,
+)
+
+# Export the source code so that it can be compiled for Android native apps.
+filegroup(
+ name = "fft2d_srcs",
+ srcs = FFT2D_SRCS,
+)
diff --git a/third_party/fft2d/fft2d.h b/third_party/fft2d/fft2d.h
new file mode 100644
index 00000000..d587b3b4
--- /dev/null
+++ b/third_party/fft2d/fft2d.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Declarations for 2D FFT routines in third_party/fft2d/fft2d.
+
+#ifndef FFT2D_FFT2D_H__
+#define FFT2D_FFT2D_H__
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+extern void cdft2d(int, int, int, double **, double *, int *, double *);
+extern void rdft2d(int, int, int, double **, double *, int *, double *);
+extern void ddct2d(int, int, int, double **, double *, int *, double *);
+extern void ddst2d(int, int, int, double **, double *, int *, double *);
+extern void ddct8x8s(int isgn, double **a);
+extern void ddct16x16s(int isgn, double **a);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // FFT2D_FFT2D_H__
diff --git a/third_party/flatbuffers/BUILD b/third_party/flatbuffers/BUILD
new file mode 100644
index 00000000..82bab3ff
--- /dev/null
+++ b/third_party/flatbuffers/BUILD
@@ -0,0 +1 @@
+# This empty BUILD file is required to make Bazel treat this directory as a package.
diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel
new file mode 100644
index 00000000..1ee46f05
--- /dev/null
+++ b/third_party/flatbuffers/BUILD.bazel
@@ -0,0 +1,140 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE.txt"])
+
+licenses(["notice"])
+
+config_setting(
+ name = "freebsd",
+ values = {"cpu": "freebsd"},
+)
+
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
+
+load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
+
+# Public flatc library to compile flatbuffer files at runtime.
+cc_library(
+ name = "flatbuffers",
+ hdrs = ["//:public_headers"],
+ linkstatic = 1,
+ strip_include_prefix = "/include",
+ visibility = ["//visibility:public"],
+ deps = ["//src:flatbuffers"],
+)
+
+# Public C++ headers for the Flatbuffers library.
+filegroup(
+ name = "public_headers",
+ srcs = [
+ "include/flatbuffers/base.h",
+ "include/flatbuffers/code_generators.h",
+ "include/flatbuffers/flatbuffers.h",
+ "include/flatbuffers/flexbuffers.h",
+ "include/flatbuffers/hash.h",
+ "include/flatbuffers/idl.h",
+ "include/flatbuffers/minireflect.h",
+ "include/flatbuffers/reflection.h",
+ "include/flatbuffers/reflection_generated.h",
+ "include/flatbuffers/registry.h",
+ "include/flatbuffers/stl_emulation.h",
+ "include/flatbuffers/util.h",
+ ],
+ visibility = ["//:__subpackages__"],
+)
+
+# Public flatc compiler library.
+cc_library(
+ name = "flatc_library",
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ "@flatbuffers//src:flatc_library",
+ ],
+)
+
+# Public flatc compiler.
+cc_binary(
+ name = "flatc",
+ linkopts = select({
+ ":freebsd": [
+ "-lm",
+ ],
+ ":windows": [],
+ "//conditions:default": [
+ "-lm",
+ "-ldl",
+ ],
+ }),
+ visibility = ["//visibility:public"],
+ deps = [
+ "@flatbuffers//src:flatc",
+ ],
+)
+
+filegroup(
+ name = "flatc_headers",
+ srcs = [
+ "include/flatbuffers/flatc.h",
+ ],
+ visibility = ["//:__subpackages__"],
+)
+
+# Library used by flatbuffer_cc_library rules.
+cc_library(
+ name = "runtime_cc",
+ hdrs = [
+ "include/flatbuffers/base.h",
+ "include/flatbuffers/flatbuffers.h",
+ "include/flatbuffers/flexbuffers.h",
+ "include/flatbuffers/stl_emulation.h",
+ "include/flatbuffers/util.h",
+ ],
+ linkstatic = 1,
+ strip_include_prefix = "/include",
+ visibility = ["//visibility:public"],
+)
+
+filegroup(
+ name = "runtime_py_srcs",
+ srcs = [
+ "python/flatbuffers/__init__.py",
+ "python/flatbuffers/builder.py",
+ "python/flatbuffers/compat.py",
+ "python/flatbuffers/encode.py",
+ "python/flatbuffers/number_types.py",
+ "python/flatbuffers/packer.py",
+ "python/flatbuffers/table.py",
+ "python/flatbuffers/util.py",
+ ],
+)
+
+py_library(
+ name = "runtime_py",
+ srcs = [":runtime_py_srcs"],
+ visibility = ["//visibility:public"],
+)
+
+filegroup(
+ name = "runtime_java_srcs",
+ srcs = glob(["java/com/google/flatbuffers/**/*.java"]),
+)
+
+java_library(
+ name = "runtime_java",
+ srcs = [":runtime_java_srcs"],
+ visibility = ["//visibility:public"],
+)
+
+android_library(
+ name = "runtime_android",
+ srcs = [":runtime_java_srcs"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl
new file mode 100644
index 00000000..986b8d12
--- /dev/null
+++ b/third_party/flatbuffers/build_defs.bzl
@@ -0,0 +1,617 @@
+"""BUILD rules for generating flatbuffer files."""
+
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+
+flatc_path = "@flatbuffers//:flatc"
+zip_files = "//tensorflow_lite_support/tools:zip_files"
+
+DEFAULT_INCLUDE_PATHS = [
+ "./",
+ "$(GENDIR)",
+ "$(BINDIR)",
+]
+
+DEFAULT_FLATC_ARGS = [
+ "--no-union-value-namespacing",
+ "--gen-object-api",
+]
+
+def flatbuffer_library_public(
+ name,
+ srcs,
+ outs,
+ language_flag,
+ out_prefix = "",
+ includes = [],
+ include_paths = [],
+ compatible_with = [],
+ flatc_args = DEFAULT_FLATC_ARGS,
+ reflection_name = "",
+ reflection_visibility = None,
+ output_to_bindir = False):
+ """Generates code files for reading/writing the given flatbuffers in the requested language using the public compiler.
+
+ Outs:
+ filegroup(name): all generated source files.
+ Fileset([reflection_name]): (Optional) all generated reflection binaries.
+
+ Args:
+ name: Rule name.
+ srcs: Source .fbs files. Sent in order to the compiler.
+ outs: Output files from flatc.
+ language_flag: Target language flag. One of [-c, -j, -js].
+ out_prefix: Prepend this path to the front of all generated files except on
+          single source targets. Usually a directory name.
+ includes: Optional, list of filegroups of schemas that the srcs depend on.
+ include_paths: Optional, list of paths the includes files can be found in.
+ compatible_with: Optional, passed to genrule for environments this rule
+ can be built for.
+ flatc_args: Optional, list of additional arguments to pass to flatc.
+ reflection_name: Optional, if set this will generate the flatbuffer
+ reflection binaries for the schemas.
+ reflection_visibility: The visibility of the generated reflection Fileset.
+ output_to_bindir: Passed to genrule for output to bin directory.
+ """
+ include_paths_cmd = ["-I %s" % (s) for s in include_paths]
+
+ # '$(@D)' when given a single source target will give the appropriate
+ # directory. Appending 'out_prefix' is only necessary when given a build
+ # target with multiple sources.
+ output_directory = (
+ ("-o $(@D)/%s" % (out_prefix)) if len(srcs) > 1 else ("-o $(@D)")
+ )
+ genrule_cmd = " ".join([
+ "for f in $(SRCS); do",
+ "$(location %s)" % (flatc_path),
+ " ".join(flatc_args),
+ " ".join(include_paths_cmd),
+ language_flag,
+ output_directory,
+ "$$f;",
+ "done",
+ ])
+ native.genrule(
+ name = name,
+ srcs = srcs,
+ outs = outs,
+ output_to_bindir = output_to_bindir,
+ compatible_with = compatible_with,
+ tools = includes + [flatc_path],
+ cmd = genrule_cmd,
+ message = "Generating flatbuffer files for %s:" % (name),
+ )
+ if reflection_name:
+ reflection_genrule_cmd = " ".join([
+ "for f in $(SRCS); do",
+ "$(location %s)" % (flatc_path),
+ "-b --schema",
+ " ".join(flatc_args),
+ " ".join(include_paths_cmd),
+ language_flag,
+ output_directory,
+ "$$f;",
+ "done",
+ ])
+ reflection_outs = [
+ (out_prefix + "%s.bfbs") % (s.replace(".fbs", "").split("/")[-1])
+ for s in srcs
+ ]
+ native.genrule(
+ name = "%s_srcs" % reflection_name,
+ srcs = srcs,
+ outs = reflection_outs,
+ output_to_bindir = output_to_bindir,
+ compatible_with = compatible_with,
+ tools = includes + [flatc_path],
+ cmd = reflection_genrule_cmd,
+ message = "Generating flatbuffer reflection binary for %s:" % (name),
+ )
+ # TODO(b/114456773): Make bazel rules proper and supported by flatbuffer
+ # Have to comment this since FilesetEntry is not supported in bazel
+ # skylark.
+ # native.Fileset(
+ # name = reflection_name,
+ # out = "%s_out" % reflection_name,
+ # entries = [
+ # native.FilesetEntry(files = reflection_outs),
+ # ],
+ # visibility = reflection_visibility,
+ # compatible_with = compatible_with,
+ # )
+
+def flatbuffer_cc_library(
+ name,
+ srcs,
+ srcs_filegroup_name = "",
+ out_prefix = "",
+ includes = [],
+ include_paths = [],
+ compatible_with = [],
+ flatc_args = DEFAULT_FLATC_ARGS,
+ visibility = None,
+ srcs_filegroup_visibility = None,
+ gen_reflections = False):
+ '''A cc_library with the generated reader/writers for the given flatbuffer definitions.
+
+ Outs:
+ filegroup([name]_srcs): all generated .h files.
+ filegroup(srcs_filegroup_name if specified, or [name]_includes if not):
+        Other flatbuffer_cc_library rules can pass this in for their `includes`
+ parameter, if they depend on the schemas in this library.
+ Fileset([name]_reflection): (Optional) all generated reflection binaries.
+ cc_library([name]): library with sources and flatbuffers deps.
+
+ Remarks:
+    ** Because the genrule used to call flatc has no trivial way of computing
+    the output list of files transitively generated by includes, and
+    --gen-includes (the default) is defined for flatc, the --gen-includes
+    flag will not work as expected. The way around this is to add a dependency
+    on the flatbuffer_cc_library defined alongside the flatc-included Fileset.
+ For example you might define:
+
+ flatbuffer_cc_library(
+ name = "my_fbs",
+ srcs = [ "schemas/foo.fbs" ],
+ includes = [ "//third_party/bazz:bazz_fbs_includes" ],
+ )
+
+ In which foo.fbs includes a few files from the Fileset defined at
+ //third_party/bazz:bazz_fbs_includes. When compiling the library that
+ includes foo_generated.h, and therefore has my_fbs as a dependency, it
+ will fail to find any of the bazz *_generated.h files unless you also
+ add bazz's flatbuffer_cc_library to your own dependency list, e.g.:
+
+ cc_library(
+ name = "my_lib",
+ deps = [
+ ":my_fbs",
+ "//third_party/bazz:bazz_fbs"
+ ],
+ )
+
+ Happy dependent Flatbuffering!
+
+ Args:
+ name: Rule name.
+ srcs: Source .fbs files. Sent in order to the compiler.
+ srcs_filegroup_name: Name of the output filegroup that holds srcs. Pass this
+ filegroup into the `includes` parameter of any other
+ flatbuffer_cc_library that depends on this one's schemas.
+ out_prefix: Prepend this path to the front of all generated files. Usually
+          a directory name.
+ includes: Optional, list of filegroups of schemas that the srcs depend on.
+ ** SEE REMARKS BELOW **
+ include_paths: Optional, list of paths the includes files can be found in.
+ compatible_with: Optional, passed to genrule for environments this rule
+ can be built for
+ flatc_args: Optional list of additional arguments to pass to flatc
+ (e.g. --gen-mutable).
+ visibility: The visibility of the generated cc_library. By default, use the
+ default visibility of the project.
+ srcs_filegroup_visibility: The visibility of the generated srcs filegroup.
+ By default, use the value of the visibility parameter above.
+ gen_reflections: Optional, if true this will generate the flatbuffer
+ reflection binaries for the schemas.
+ '''
+ output_headers = [
+ (out_prefix + "%s_generated.h") % (s.replace(".fbs", "").split("/")[-1])
+ for s in srcs
+ ]
+ reflection_name = "%s_reflection" % name if gen_reflections else ""
+
+ flatbuffer_library_public(
+ name = "%s_srcs" % (name),
+ srcs = srcs,
+ outs = output_headers,
+ language_flag = "-c",
+ out_prefix = out_prefix,
+ includes = includes,
+ include_paths = include_paths,
+ compatible_with = compatible_with,
+ flatc_args = flatc_args,
+ reflection_name = reflection_name,
+ reflection_visibility = visibility,
+ )
+ native.cc_library(
+ name = name,
+ hdrs = output_headers,
+ srcs = output_headers,
+ features = [
+ "-parse_headers",
+ ],
+ deps = [
+ "@flatbuffers//:runtime_cc",
+ ],
+ includes = ["."],
+ linkstatic = 1,
+ visibility = visibility,
+ compatible_with = compatible_with,
+ )
+
+ # A filegroup for the `srcs`. That is, all the schema files for this
+ # Flatbuffer set.
+ native.filegroup(
+ name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name),
+ srcs = srcs,
+ visibility = srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility,
+ compatible_with = compatible_with,
+ )
+
+# Custom provider to track dependencies transitively.
+FlatbufferInfo = provider(
+ fields = {
+ "transitive_srcs": "flatbuffer schema definitions.",
+ },
+)
+
+def _flatbuffer_schemas_aspect_impl(target, ctx):
+ _ignore = [target]
+ transitive_srcs = depset()
+ if hasattr(ctx.rule.attr, "deps"):
+ for dep in ctx.rule.attr.deps:
+ if FlatbufferInfo in dep:
+ transitive_srcs = depset(dep[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs])
+ if hasattr(ctx.rule.attr, "srcs"):
+ for src in ctx.rule.attr.srcs:
+ if FlatbufferInfo in src:
+ transitive_srcs = depset(src[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs])
+ for f in src.files:
+ if f.extension == "fbs":
+ transitive_srcs = depset([f], transitive = [transitive_srcs])
+ return [FlatbufferInfo(transitive_srcs = transitive_srcs)]
+
+# An aspect that runs over all dependencies and transitively collects
+# flatbuffer schema files.
+_flatbuffer_schemas_aspect = aspect(
+ attr_aspects = [
+ "deps",
+ "srcs",
+ ],
+ implementation = _flatbuffer_schemas_aspect_impl,
+)
+
+# Rule to invoke the flatbuffer compiler.
+def _gen_flatbuffer_srcs_impl(ctx):
+ outputs = ctx.attr.outputs
+ include_paths = ctx.attr.include_paths
+ if ctx.attr.no_includes:
+ no_includes_statement = ["--no-includes"]
+ else:
+ no_includes_statement = []
+
+ # Need to generate all files in a directory.
+ if not outputs:
+ outputs = [ctx.actions.declare_directory("{}_all".format(ctx.attr.name))]
+ output_directory = outputs[0].path
+ else:
+ outputs = [ctx.actions.declare_file(output) for output in outputs]
+ output_directory = outputs[0].dirname
+
+ deps = depset(ctx.files.srcs + ctx.files.deps, transitive = [
+ dep[FlatbufferInfo].transitive_srcs
+ for dep in ctx.attr.deps
+ if FlatbufferInfo in dep
+ ])
+
+ include_paths_cmd_line = []
+ for s in include_paths:
+ include_paths_cmd_line.extend(["-I", s])
+
+ for src in ctx.files.srcs:
+ ctx.actions.run(
+ inputs = deps,
+ outputs = outputs,
+ executable = ctx.executable._flatc,
+ arguments = [
+ ctx.attr.language_flag,
+ "-o",
+ output_directory,
+ # Allow for absolute imports and referencing of generated files.
+ "-I",
+ "./",
+ "-I",
+ ctx.genfiles_dir.path,
+ "-I",
+ ctx.bin_dir.path,
+ ] + no_includes_statement +
+ include_paths_cmd_line + [
+ "--no-union-value-namespacing",
+ "--gen-object-api",
+ src.path,
+ ],
+ progress_message = "Generating flatbuffer files for {}:".format(src),
+ )
+ return [
+ DefaultInfo(files = depset(outputs)),
+ ]
+
+_gen_flatbuffer_srcs = rule(
+ _gen_flatbuffer_srcs_impl,
+ attrs = {
+ "srcs": attr.label_list(
+ allow_files = [".fbs"],
+ mandatory = True,
+ ),
+ "outputs": attr.string_list(
+ default = [],
+ mandatory = False,
+ ),
+ "deps": attr.label_list(
+ default = [],
+ mandatory = False,
+ aspects = [_flatbuffer_schemas_aspect],
+ ),
+ "include_paths": attr.string_list(
+ default = [],
+ mandatory = False,
+ ),
+ "language_flag": attr.string(
+ mandatory = True,
+ ),
+ "no_includes": attr.bool(
+ default = False,
+ mandatory = False,
+ ),
+ "_flatc": attr.label(
+ default = Label("@flatbuffers//:flatc"),
+ executable = True,
+ cfg = "host",
+ ),
+ },
+ output_to_genfiles = True,
+)
+
+def _concat_flatbuffer_py_srcs_impl(ctx):
+ # Merge all generated python files. The files are concatenated and the
+ # import statements are removed. Finally we import the flatbuffer runtime
+ # library.
+ command = "echo 'import flatbuffers\n' > %s; "
+ command += "for f in $(find %s -name '*.py'); do cat $f | sed '/import flatbuffers/d' >> %s; done "
+ ctx.actions.run_shell(
+ inputs = ctx.attr.deps[0].files,
+ outputs = [ctx.outputs.out],
+ command = command % (
+ ctx.outputs.out.path,
+ ctx.attr.deps[0].files.to_list()[0].path,
+ ctx.outputs.out.path,
+ ),
+ )
+
+_concat_flatbuffer_py_srcs = rule(
+ _concat_flatbuffer_py_srcs_impl,
+ attrs = {
+ "deps": attr.label_list(mandatory = True),
+ },
+ output_to_genfiles = True,
+ outputs = {"out": "%{name}.py"},
+)
+
+def flatbuffer_py_library(
+ name,
+ srcs,
+ deps = [],
+ include_paths = []):
+ """A py_library with the generated reader/writers for the given schema.
+
+    This rule assumes that the schema files define non-conflicting names, so that
+    they can be merged into a single file. This is, e.g., the case if only a
+    single namespace is used.
+    The rule calls the flatbuffer compiler for all schema files and merges the
+    generated Python files into a single file that is wrapped in a py_library.
+
+ Args:
+ name: Rule name. (required)
+ srcs: List of source .fbs files. (required)
+ deps: List of dependencies.
+ include_paths: Optional, list of paths the includes files can be found in.
+ """
+ all_srcs = "{}_srcs".format(name)
+ _gen_flatbuffer_srcs(
+ name = all_srcs,
+ srcs = srcs,
+ language_flag = "--python",
+ deps = deps,
+ include_paths = include_paths,
+ )
+ all_srcs_no_include = "{}_srcs_no_include".format(name)
+ _gen_flatbuffer_srcs(
+ name = all_srcs_no_include,
+ srcs = srcs,
+ language_flag = "--python",
+ deps = deps,
+ no_includes = True,
+ include_paths = include_paths,
+ )
+ concat_py_srcs = "{}_generated".format(name)
+ _concat_flatbuffer_py_srcs(
+ name = concat_py_srcs,
+ deps = [
+ ":{}".format(all_srcs_no_include),
+ ],
+ )
+ native.py_library(
+ name = name,
+ srcs = [
+ ":{}".format(concat_py_srcs),
+ ],
+ srcs_version = "PY2AND3",
+ deps = deps,
+ )
+
+def flatbuffer_java_library(
+ name,
+ srcs,
+ custom_package = "",
+ package_prefix = "",
+ include_paths = DEFAULT_INCLUDE_PATHS,
+ flatc_args = DEFAULT_FLATC_ARGS,
+ visibility = None):
+ """A java library with the generated reader/writers for the given flatbuffer definitions.
+
+ Args:
+ name: Rule name. (required)
+ srcs: List of source .fbs files including all includes. (required)
+ custom_package: Package name of generated Java files. If not specified
+ namespace in the schema files will be used. (optional)
+ package_prefix: like custom_package, but prefixes to the existing
+ namespace. (optional)
+ include_paths: List of paths that includes files can be found in. (optional)
+ flatc_args: List of additional arguments to pass to flatc. (optional)
+ visibility: Visibility setting for the java_library rule. (optional)
+ """
+ out_srcjar = "java_%s_all.srcjar" % name
+ flatbuffer_java_srcjar(
+ name = "%s_srcjar" % name,
+ srcs = srcs,
+ out = out_srcjar,
+ custom_package = custom_package,
+ flatc_args = flatc_args,
+ include_paths = include_paths,
+ package_prefix = package_prefix,
+ )
+
+ native.filegroup(
+ name = "%s.srcjar" % name,
+ srcs = [out_srcjar],
+ )
+
+ native.java_library(
+ name = name,
+ srcs = [out_srcjar],
+ javacopts = ["-source 7 -target 7"],
+ deps = [
+ "@flatbuffers//:runtime_java",
+ ],
+ visibility = visibility,
+ )
+
+def flatbuffer_java_srcjar(
+ name,
+ srcs,
+ out,
+ custom_package = "",
+ package_prefix = "",
+ include_paths = DEFAULT_INCLUDE_PATHS,
+ flatc_args = DEFAULT_FLATC_ARGS):
+ """Generate flatbuffer Java source files.
+
+ Args:
+ name: Rule name. (required)
+ srcs: List of source .fbs files including all includes. (required)
+ out: Output file name. (required)
+ custom_package: Package name of generated Java files. If not specified
+ namespace in the schema files will be used. (optional)
+ package_prefix: like custom_package, but prefixes to the existing
+ namespace. (optional)
+ include_paths: List of paths that includes files can be found in. (optional)
+ flatc_args: List of additional arguments to pass to flatc. (optional)
+ """
+ command_fmt = """set -e
+ tmpdir=$(@D)
+ schemas=$$tmpdir/schemas
+ java_root=$$tmpdir/java
+ rm -rf $$schemas
+ rm -rf $$java_root
+ mkdir -p $$schemas
+ mkdir -p $$java_root
+
+ for src in $(SRCS); do
+ dest=$$schemas/$$src
+ rm -rf $$(dirname $$dest)
+ mkdir -p $$(dirname $$dest)
+ if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then
+ cp -f $$src $$dest
+ else
+ if [ -z "{package_prefix}" ]; then
+ sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest
+ else
+ sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest
+ fi
+ fi
+ done
+
+ flatc_arg_I="-I $$tmpdir/schemas"
+ for include_path in {include_paths}; do
+ flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path"
+ done
+
+ flatc_additional_args=
+ for arg in {flatc_args}; do
+ flatc_additional_args="$$flatc_additional_args $$arg"
+ done
+
+ for src in $(SRCS); do
+ $(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src
+ done
+
+ $(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root
+ """
+ genrule_cmd = command_fmt.format(
+ package_name = native.package_name(),
+ custom_package = custom_package,
+ package_prefix = package_prefix,
+ flatc_path = flatc_path,
+ zip_files = zip_files,
+ include_paths = " ".join(include_paths),
+ flatc_args = " ".join(flatc_args),
+ )
+
+ native.genrule(
+ name = name,
+ srcs = srcs,
+ outs = [out],
+ tools = [flatc_path, zip_files],
+ cmd = genrule_cmd,
+ )
+
+def flatbuffer_android_library(
+ name,
+ srcs,
+ custom_package = "",
+ package_prefix = "",
+ include_paths = DEFAULT_INCLUDE_PATHS,
+ flatc_args = DEFAULT_FLATC_ARGS,
+ visibility = None):
+ """An android_library with the generated reader/writers for the given flatbuffer definitions.
+
+ Args:
+ name: Rule name. (required)
+ srcs: List of source .fbs files including all includes. (required)
+ custom_package: Package name of generated Java files. If not specified
+ namespace in the schema files will be used. (optional)
+ package_prefix: like custom_package, but prefixes to the existing
+ namespace. (optional)
+ include_paths: List of paths that includes files can be found in. (optional)
+ flatc_args: List of additional arguments to pass to flatc. (optional)
+ visibility: Visibility setting for the android_library rule. (optional)
+ """
+ out_srcjar = "android_%s_all.srcjar" % name
+ flatbuffer_java_srcjar(
+ name = "%s_srcjar" % name,
+ srcs = srcs,
+ out = out_srcjar,
+ custom_package = custom_package,
+ flatc_args = flatc_args,
+ include_paths = include_paths,
+ package_prefix = package_prefix,
+ )
+
+ native.filegroup(
+ name = "%s.srcjar" % name,
+ srcs = [out_srcjar],
+ )
+
+ # To support org.checkerframework.dataflow.qual.Pure.
+ checkerframework_annotations = [
+ "@org_checkerframework_qual",
+ ] if "--java-checkerframework" in flatc_args else []
+
+ android_library(
+ name = name,
+ srcs = [out_srcjar],
+ javacopts = ["-source 7 -target 7"],
+ visibility = visibility,
+ deps = [
+ "@flatbuffers//:runtime_android",
+ ] + checkerframework_annotations,
+ )
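A hedged BUILD sketch of the two most common macros defined above (the schema file name is hypothetical):

```python
load("//third_party/flatbuffers:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_java_library")

# Generates foo_generated.h and wraps it in a cc_library named :foo_fbs.
flatbuffer_cc_library(
    name = "foo_fbs",
    srcs = ["foo.fbs"],  # hypothetical schema
)

# Generates Java readers/writers into a srcjar and wraps it in a java_library.
flatbuffer_java_library(
    name = "foo_fbs_java",
    srcs = ["foo.fbs"],
)
```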
diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl
new file mode 100644
index 00000000..dea463f2
--- /dev/null
+++ b/third_party/flatbuffers/workspace.bzl
@@ -0,0 +1,19 @@
+"""Loads the Flatbuffers library, used by TF Lite."""
+
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+def repo():
+ third_party_http_archive(
+ name = "flatbuffers",
+ strip_prefix = "flatbuffers-1.12.0",
+ sha256 = "62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
+ "https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz",
+ ],
+ build_file = "//third_party/flatbuffers:BUILD.bazel",
+ delete = ["build_defs.bzl"],
+ link_files = {
+ "//third_party/flatbuffers:build_defs.bzl": "build_defs.bzl",
+ },
+ )
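A WORKSPACE sketch of how this repo() is typically wired in (load label assumed from this repository's layout):

```python
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()  # instantiates the @flatbuffers repository defined above
```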
diff --git a/third_party/gflags/BUILD b/third_party/gflags/BUILD
new file mode 100644
index 00000000..82bab3ff
--- /dev/null
+++ b/third_party/gflags/BUILD
@@ -0,0 +1 @@
+# This empty BUILD file is required to make Bazel treat this directory as a package.
diff --git a/third_party/gflags/fix_android_pthread_link.patch b/third_party/gflags/fix_android_pthread_link.patch
new file mode 100644
index 00000000..9a0b3511
--- /dev/null
+++ b/third_party/gflags/fix_android_pthread_link.patch
@@ -0,0 +1,32 @@
+diff --git a/BUILD b/BUILD
+index 0a5c9eb..d836578 100644
+--- a/BUILD
++++ b/BUILD
+@@ -6,6 +6,11 @@ licenses(["notice"])
+
+ exports_files(["src/gflags_complections.sh", "COPYING.txt"])
+
++config_setting(
++ name = "android",
++ values = {"crosstool_top": "//external:android/crosstool"},
++)
++
+ load(":bazel/gflags.bzl", "gflags_sources", "gflags_library")
+ (hdrs, srcs) = gflags_sources(namespace=["gflags", "google"])
+ gflags_library(hdrs=hdrs, srcs=srcs, threads=0)
+diff --git a/bazel/gflags.bzl b/bazel/gflags.bzl
+index cd0edad..5c1d8b5 100644
+--- a/bazel/gflags.bzl
++++ b/bazel/gflags.bzl
+@@ -77,7 +77,10 @@ def gflags_library(hdrs=[], srcs=[], threads=1):
+ ]
+ linkopts = []
+ if threads:
+- linkopts.append("-lpthread")
++ linkopts += select({
++ "//:android": [],
++ "//conditions:default": ["-lpthread"],
++ })
+ else:
+ name += "_nothreads"
+ copts.append("-DNO_THREADS") \ No newline at end of file
diff --git a/third_party/gflags/workspace.bzl b/third_party/gflags/workspace.bzl
new file mode 100644
index 00000000..194a9d3f
--- /dev/null
+++ b/third_party/gflags/workspace.bzl
@@ -0,0 +1,16 @@
+"""Loads the GFlags repo and patch it with android linkopt fix."""
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
+def repo():
+ http_archive(
+ name = "com_github_gflags_gflags",
+ sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
+ strip_prefix = "gflags-2.2.1",
+ urls = [
+ "http://mirror.tensorflow.org/github.com/gflags/gflags/archive/v2.2.1.tar.gz",
+ "https://github.com/gflags/gflags/archive/v2.2.1.tar.gz",
+ ],
+ patches = ["@//third_party/gflags:fix_android_pthread_link.patch"],
+ patch_args = ["-p1"],
+ )
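As above, a hedged WORKSPACE sketch for pulling in the patched repository:

```python
load("//third_party/gflags:workspace.bzl", gflags = "repo")
gflags()  # defines @com_github_gflags_gflags with the Android pthread fix applied
```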
diff --git a/third_party/google_toolbox_for_mac.BUILD b/third_party/google_toolbox_for_mac.BUILD
new file mode 100644
index 00000000..8d7fecf3
--- /dev/null
+++ b/third_party/google_toolbox_for_mac.BUILD
@@ -0,0 +1,22 @@
+# Description:
+#   A collection of source code from different Google projects that may be of
+#   use to developers working on other Mac projects.
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+exports_files(
+ ["UnitTest-Info.plist"],
+ visibility = ["//visibility:public"],
+)
+
+objc_library(
+ name = "GTM_Defines",
+ hdrs = ["GTMDefines.h"],
+ includes = ["."],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/icu.BUILD b/third_party/icu.BUILD
new file mode 100644
index 00000000..7749dda0
--- /dev/null
+++ b/third_party/icu.BUILD
@@ -0,0 +1,97 @@
+"""Builds ICU library."""
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files([
+ "icu4c/LICENSE",
+ "icu4j/main/shared/licenses/LICENSE",
+])
+
+cc_library(
+ name = "headers",
+ hdrs = glob(["icu4c/source/common/unicode/*.h"]),
+ includes = [
+ "icu4c/source/common",
+ ],
+ deps = [
+ ],
+)
+
+cc_library(
+ name = "common",
+ hdrs = glob(["icu4c/source/common/unicode/*.h"]),
+ includes = [
+ "icu4c/source/common",
+ ],
+ deps = [
+ ":icuuc",
+ ],
+)
+
+alias(
+ name = "nfkc",
+ actual = ":common",
+)
+
+alias(
+ name = "nfkc_cf",
+ actual = ":common",
+)
+
+cc_library(
+ name = "icuuc",
+ srcs = glob(
+ [
+ "icu4c/source/common/*.c",
+ "icu4c/source/common/*.cpp",
+ "icu4c/source/stubdata/*.cpp",
+ ],
+ ),
+ hdrs = glob([
+ "icu4c/source/common/*.h",
+ ]),
+ copts = [
+ "-DU_COMMON_IMPLEMENTATION",
+ ] + select({
+ ":android": [
+ "-fdata-sections",
+ "-DU_HAVE_NL_LANGINFO_CODESET=0",
+ "-Wno-deprecated-declarations",
+ ],
+ ":apple": [
+ "-Wno-shorten-64-to-32",
+ "-Wno-unused-variable",
+ ],
+ ":windows": [
+ "/utf-8",
+ "/DLOCALE_ALLOW_NEUTRAL_NAMES=0",
+ ],
+ "//conditions:default": [],
+ }),
+ tags = ["requires-rtti"],
+ visibility = [
+ "//visibility:private",
+ ],
+ deps = [
+ ":headers",
+ ],
+)
+
+config_setting(
+ name = "android",
+ values = {"crosstool_top": "//external:android/crosstool"},
+)
+
+config_setting(
+ name = "apple",
+ values = {"cpu": "darwin"},
+)
+
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
diff --git a/third_party/libyuv.BUILD b/third_party/libyuv.BUILD
new file mode 100644
index 00000000..4b39a8c0
--- /dev/null
+++ b/third_party/libyuv.BUILD
@@ -0,0 +1,25 @@
+# Description:
+#   The libyuv package provides an implementation of YUV image conversion,
+#   rotation, and scaling.
+
+licenses(["notice"]) # BSD license
+
+exports_files(["LICENSE"])
+
+cc_library(
+ name = "libyuv",
+ srcs = glob(
+ [
+ "source/*.cc",
+ "include/libyuv/*.h",
+ ],
+ ),
+ hdrs = [
+ "include/libyuv.h",
+ "include/libyuv/compare.h",
+ "include/libyuv/convert.h",
+ "include/libyuv/video_common.h",
+ ],
+ includes = ["include"],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/libzip.BUILD b/third_party/libzip.BUILD
new file mode 100644
index 00000000..b69ccf41
--- /dev/null
+++ b/third_party/libzip.BUILD
@@ -0,0 +1,189 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+load("@org_tensorflow_lite_support//tensorflow_lite_support/tools:build_rules/expand_template.bzl", "cmake_substitutions", "expand_template")
+
+_CMAKE_VARIABLES = {
+ "INT16_T_LIBZIP": 2,
+ "INT32_T_LIBZIP": 4,
+ "INT64_T_LIBZIP": 8,
+ "INT8_T_LIBZIP": 1,
+ "INT_LIBZIP": 4,
+ "LIBZIP_TYPES_INCLUDE": "#include <stdint.h>",
+ "LONG_LIBZIP": 8,
+ "LONG_LONG_LIBZIP": 8,
+ "PACKAGE_VERSION": "1.5.1",
+ "PACKAGE_VERSION_MAJOR": "1",
+ "PACKAGE_VERSION_MICRO": "1",
+ "PACKAGE_VERSION_MINOR": "5",
+ "SHORT_LIBZIP": 2,
+ "SIZEOF_OFF_T": 8,
+ "SIZE_T_LIBZIP": 8,
+ "SSIZE_T_LIBZIP": 8,
+ "UINT16_T_LIBZIP": 2,
+ "UINT32_T_LIBZIP": 4,
+ "UINT64_T_LIBZIP": 8,
+ "UINT8_T_LIBZIP": 1,
+ "__INT16_LIBZIP": None,
+ "__INT32_LIBZIP": None,
+ "__INT64_LIBZIP": None,
+ "__INT8_LIBZIP": None,
+}
+
+_CMAKE_VARIABLES.update(dict([
+ (
+ "ZIP_{sign}INT{size}_T".format(
+ sign = sign.upper(),
+ size = size,
+ ),
+ "{sign}int{size}_t".format(
+ sign = sign.lower(),
+ size = size,
+ ),
+ )
+ for sign in ("U", "")
+ for size in (8, 16, 32, 64)
+]))
+
+_SUBSTITUTIONS = {
+ "@PACKAGE@": "libzip",
+ "@VERSION@": "1.5.1", # Keep in sync with actual package!
+}
+
+_DEFINES = {
+ "HAVE_CLONEFILE": False,
+ "HAVE_COMMONCRYPTO": False,
+ "HAVE_CRYPTO": False,
+ "HAVE_DIRENT_H": False,
+ "HAVE_FICLONERANGE": False,
+ "HAVE_FILENO": True,
+ "HAVE_FSEEK": True,
+ "HAVE_FSEEKO": True,
+ "HAVE_FTELLO": True,
+ "HAVE_FTS_H": True,
+ "HAVE_GETPROGNAME": False,
+ "HAVE_GNUTLS": False,
+ "HAVE_LIBBZ2": False,
+ "HAVE_MKSTEMP": True,
+ "HAVE_NDIR_H": False,
+ "HAVE_OPEN": True,
+ "HAVE_OPENSSL": False,
+ "HAVE_SETMODE": False,
+ "HAVE_SHARED": True,
+ "HAVE_SNPRINTF": True,
+ "HAVE_SSIZE_T_LIBZIP": True,
+ "HAVE_STDBOOL_H": True,
+ "HAVE_STRCASECMP": True,
+ "HAVE_STRDUP": True,
+ "HAVE_STRICMP": False,
+ "HAVE_STRINGS_H": True,
+ "HAVE_STRTOLL": True,
+ "HAVE_STRTOULL": True,
+ "HAVE_STRUCT_TM_TM_ZONE": False,
+ "HAVE_SYS_DIR_H": False,
+ "HAVE_SYS_NDIR_H": False,
+ "HAVE_UNISTD_H": True,
+ "HAVE__CHMOD": False,
+ "HAVE__CLOSE": False,
+ "HAVE__DUP": False,
+ "HAVE__FDOPEN": False,
+ "HAVE__FILENO": False,
+ "HAVE__OPEN": False,
+ "HAVE__SETMODE": False,
+ "HAVE__SNPRINTF": False,
+ "HAVE__STRDUP": False,
+ "HAVE__STRICMP": False,
+ "HAVE__STRTOI64": False,
+ "HAVE__STRTOUI64": False,
+ "HAVE__UMASK": False,
+ "HAVE__UNLINK": False,
+ "HAVE___PROGNAME": False,
+ "WORDS_BIGENDIAN": False,
+}
+
+_DEFINES.update(dict([(
+ key,
+ value != None,
+) for key, value in _CMAKE_VARIABLES.items()]))
+
+_SUBSTITUTIONS.update(cmake_substitutions(
+ defines = _DEFINES,
+ vars = _CMAKE_VARIABLES,
+))
+
+expand_template(
+ name = "config_h",
+ out = "config.h",
+ substitutions = _SUBSTITUTIONS,
+ template = "cmake-config.h.in",
+)
+
+_VARS = {
+ "LIBZIP_TYPES_INCLUDE": "#include <stdint.h>",
+ "PACKAGE_VERSION": "1.5.1",
+ "PACKAGE_VERSION_MAJOR": "1",
+ "PACKAGE_VERSION_MICRO": "1",
+ "PACKAGE_VERSION_MINOR": "5",
+}
+
+_VARS.update(dict([
+ (
+ "ZIP_{sign}INT{size}_T".format(
+ sign = sign.upper(),
+ size = size,
+ ),
+ "{sign}int{size}_t".format(
+ sign = sign.lower(),
+ size = size,
+ ),
+ )
+ for sign in ("U", "")
+ for size in (8, 16, 32, 64)
+]))
+
+expand_template(
+ name = "zipconf_h",
+ out = "lib/zipconf.h",
+ substitutions = cmake_substitutions(
+ defines = {
+ "LIBZIP_VERSION": True,
+ "LIBZIP_VERSION_MAJOR": True,
+ "LIBZIP_VERSION_MICRO": True,
+ "LIBZIP_VERSION_MINOR": True,
+ "ZIP_STATIC": False,
+ },
+ vars = _VARS,
+ ),
+ template = "cmake-zipconf.h.in",
+)
+
+cc_library(
+ name = "zip",
+ srcs = glob(
+ [
+ "lib/*.c",
+ "lib/*.h",
+ ],
+ exclude = [
+ "lib/*win32*",
+ "lib/zip_random_uwp.c",
+ "lib/*crypto*",
+ "lib/*aes*",
+ "lib/*bzip2*",
+ ],
+ ) + [
+ "config.h",
+ ],
+ hdrs = [
+ "lib/zip.h",
+ "lib/zipconf.h",
+ ],
+ copts = [
+ "-DHAVE_CONFIG_H",
+ ],
+ includes = ["lib"],
+ deps = [
+ "@zlib",
+ ],
+)
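For readability, the same comprehension in plain Python shows the type-name entries that the two dict updates above generate:

```python
# Mirrors the Starlark comprehension; prints ZIP_INT8_T -> int8_t, ...,
# ZIP_UINT64_T -> uint64_t.
entries = {
    "ZIP_{sign}INT{size}_T".format(sign = s.upper(), size = n):
        "{sign}int{size}_t".format(sign = s.lower(), size = n)
    for s in ("U", "")
    for n in (8, 16, 32, 64)
}
for key in sorted(entries):
    print(key, "->", entries[key])
```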
diff --git a/third_party/py/BUILD b/third_party/py/BUILD
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/third_party/py/BUILD
diff --git a/third_party/py/BUILD.tpl b/third_party/py/BUILD.tpl
new file mode 100644
index 00000000..cc0e013b
--- /dev/null
+++ b/third_party/py/BUILD.tpl
@@ -0,0 +1,31 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+# Point both runtimes to the same python binary to ensure we always
+# use the python binary specified by ./configure.py script.
+load("@bazel_tools//tools/python:toolchain.bzl", "py_runtime_pair")
+
+py_runtime(
+ name = "py2_runtime",
+ interpreter_path = "%{PYTHON_BIN_PATH}",
+ python_version = "PY2",
+)
+
+py_runtime(
+ name = "py3_runtime",
+ interpreter_path = "%{PYTHON_BIN_PATH}",
+ python_version = "PY3",
+)
+
+py_runtime_pair(
+ name = "py_runtime_pair",
+ py2_runtime = ":py2_runtime",
+ py3_runtime = ":py3_runtime",
+)
+
+toolchain(
+ name = "py_toolchain",
+ toolchain = ":py_runtime_pair",
+ toolchain_type = "@bazel_tools//tools/python:toolchain_type",
+)
diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl
new file mode 100644
index 00000000..6601d7f2
--- /dev/null
+++ b/third_party/py/python_configure.bzl
@@ -0,0 +1,71 @@
+"""Repository rule for Python autoconfiguration.
+
+`python_configure` depends on the following environment variables:
+
+ * `PYTHON_BIN_PATH`: location of python binary.
+"""
+
+_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
+
+def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
+ if not out:
+ out = tpl
+ repository_ctx.template(
+ out,
+ Label("//third_party/py:%s.tpl" % tpl),
+ substitutions,
+ )
+
+def _fail(msg):
+ """Output failure message when auto configuration fails."""
+ red = "\033[0;31m"
+ no_color = "\033[0m"
+ fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg))
+
+def _get_python_bin(repository_ctx):
+ """Gets the python bin path."""
+ python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
+ if python_bin != None:
+ return python_bin
+ python_bin_path = repository_ctx.which("python")
+ if python_bin_path != None:
+ return str(python_bin_path)
+ _fail("Cannot find python in PATH, please make sure " +
+ "python is installed and add its directory in PATH, or --define " +
+ "%s='/something/else'.\nPATH=%s" % (
+ _PYTHON_BIN_PATH,
+ repository_ctx.os.environ.get("PATH", ""),
+ ))
+
+def _create_local_python_repository(repository_ctx):
+ """Creates the repository containing files set up to build with Python."""
+ python_bin = _get_python_bin(repository_ctx)
+ _tpl(repository_ctx, "BUILD", {
+ "%{PYTHON_BIN_PATH}": python_bin,
+ })
+
+def _python_autoconf_impl(repository_ctx):
+ """Implementation of the python_autoconf repository rule."""
+ _create_local_python_repository(repository_ctx)
+
+python_configure = repository_rule(
+ implementation = _python_autoconf_impl,
+ environ = [
+ _PYTHON_BIN_PATH,
+ ],
+)
+"""Detects and configures the local Python toolchain.
+
+Add the following to your WORKSPACE file:
+
+```python
+load("//third_party/py:python_configure.bzl", "python_configure")
+
+python_configure(name = "local_config_py_toolchain")
+
+register_toolchains("@local_config_py_toolchain//:py_toolchain")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""
diff --git a/third_party/pybind11.BUILD b/third_party/pybind11.BUILD
new file mode 100644
index 00000000..2f1ada61
--- /dev/null
+++ b/third_party/pybind11.BUILD
@@ -0,0 +1,25 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "pybind11",
+ hdrs = glob(
+ include = [
+ "include/pybind11/*.h",
+ "include/pybind11/detail/*.h",
+ ],
+ exclude = [
+ "include/pybind11/common.h",
+ "include/pybind11/eigen.h",
+ ],
+ ),
+ copts = [
+ "-fexceptions",
+ "-Wno-undefined-inline",
+ "-Wno-pragma-once-outside-header",
+ ],
+ includes = ["include"],
+ strip_include_prefix = "include",
+ deps = [
+ "@org_tensorflow//third_party/python_runtime:headers",
+ ],
+)
diff --git a/third_party/python_runtime/BUILD b/third_party/python_runtime/BUILD
new file mode 100644
index 00000000..2a160919
--- /dev/null
+++ b/third_party/python_runtime/BUILD
@@ -0,0 +1,8 @@
+licenses(["notice"]) # New BSD, Python Software Foundation
+
+package(default_visibility = ["//visibility:public"])
+
+alias(
+ name = "headers",
+ actual = "@local_config_python//:python_headers",
+)
diff --git a/third_party/repo.bzl b/third_party/repo.bzl
new file mode 100644
index 00000000..c9c6a834
--- /dev/null
+++ b/third_party/repo.bzl
@@ -0,0 +1,152 @@
+# Copyright 2020 The TensorFlow Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for defining TensorFlow Lite Support Bazel dependencies."""
+
+_SINGLE_URL_WHITELIST = []
+
+def _is_windows(ctx):
+ return ctx.os.name.lower().find("windows") != -1
+
+def _wrap_bash_cmd(ctx, cmd):
+ if _is_windows(ctx):
+ bazel_sh = _get_env_var(ctx, "BAZEL_SH")
+ if not bazel_sh:
+ fail("BAZEL_SH environment variable is not set")
+ cmd = [bazel_sh, "-l", "-c", " ".join(["\"%s\"" % s for s in cmd])]
+ return cmd
+
+def _get_env_var(ctx, name):
+ if name in ctx.os.environ:
+ return ctx.os.environ[name]
+ else:
+ return None
+
+# Checks if we should use the system lib instead of the bundled one
+def _use_system_lib(ctx, name):
+ syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS")
+ if syslibenv:
+ for n in syslibenv.strip().split(","):
+ if n.strip() == name:
+ return True
+ return False
+
+# Executes the specified command with arguments and calls 'fail' if it exited
+# with a non-zero code.
+def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
+ result = repo_ctx.execute(cmd_and_args, timeout = 60)
+ if result.return_code != 0:
+ fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" +
+ "Stderr: {3}").format(
+ " ".join([str(x) for x in cmd_and_args]),
+ result.return_code,
+ result.stdout,
+ result.stderr,
+ ))
+
+# Applies a patch_file to the repository root directory.
+# Runs 'patch -p1' on both Windows and Unix.
+def _apply_patch(ctx, patch_file):
+ patch_command = ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)]
+ cmd = _wrap_bash_cmd(ctx, patch_command)
+ _execute_and_check_ret_code(ctx, cmd)
+
+def _apply_delete(ctx, paths):
+ for path in paths:
+ if path.startswith("/"):
+ fail("refusing to rm -rf path starting with '/': " + path)
+ if ".." in path:
+ fail("refusing to rm -rf path containing '..': " + path)
+ cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths])
+ _execute_and_check_ret_code(ctx, cmd)
+
+def _third_party_http_archive(ctx):
+ """Downloads and creates Bazel repos for dependencies.
+
+ This is a swappable replacement for both http_archive() and
+ new_http_archive() that offers some additional features. It also helps
+ ensure best practices are followed.
+ """
+ if ("mirror.tensorflow.org" not in ctx.attr.urls[0] and
+ (len(ctx.attr.urls) < 2 and
+         ctx.attr.name not in _SINGLE_URL_WHITELIST)):
+ fail("third_party_http_archive(urls) must have redundant URLs. The " +
+ "mirror.tensorflow.org URL must be present and it must come first. " +
+ "Even if you don't have permission to mirror the file, please " +
+ "put the correctly formatted mirror URL there anyway, because " +
+ "someone will come along shortly thereafter and mirror the file.")
+
+ use_syslib = _use_system_lib(ctx, ctx.attr.name)
+
+ # Use "BUILD.bazel" to avoid conflict with third party projects that contain a
+ # file or directory called "BUILD"
+ buildfile_path = ctx.path("BUILD.bazel")
+
+ if use_syslib:
+ if ctx.attr.system_build_file == None:
+ fail("Bazel was configured with TF_SYSTEM_LIBS to use a system " +
+ "library for %s, but no system build file for %s was configured. " +
+ "Please add a system_build_file attribute to the repository rule" +
+ "for %s." % (ctx.attr.name, ctx.attr.name, ctx.attr.name))
+ ctx.symlink(Label(ctx.attr.system_build_file), buildfile_path)
+
+ else:
+ ctx.download_and_extract(
+ ctx.attr.urls,
+ "",
+ ctx.attr.sha256,
+ ctx.attr.type,
+ ctx.attr.strip_prefix,
+ )
+ if ctx.attr.delete:
+ _apply_delete(ctx, ctx.attr.delete)
+ if ctx.attr.patch_file != None:
+ _apply_patch(ctx, ctx.attr.patch_file)
+ ctx.symlink(Label(ctx.attr.build_file), buildfile_path)
+
+ link_dict = {}
+ if use_syslib:
+ link_dict.update(ctx.attr.system_link_files)
+
+ for internal_src, external_dest in ctx.attr.link_files.items():
+ # if syslib and link exists in both, use the system one
+ if external_dest not in link_dict.values():
+ link_dict[internal_src] = external_dest
+
+ for internal_src, external_dest in link_dict.items():
+ ctx.symlink(Label(internal_src), ctx.path(external_dest))
+
+# For link_files, specify each dict entry as:
+# "//path/to/source:file": "localfile"
+third_party_http_archive = repository_rule(
+ attrs = {
+ "sha256": attr.string(mandatory = True),
+ "urls": attr.string_list(
+ mandatory = True,
+ allow_empty = False,
+ ),
+ "strip_prefix": attr.string(),
+ "type": attr.string(),
+ "delete": attr.string_list(),
+ "build_file": attr.string(mandatory = True),
+ "system_build_file": attr.string(mandatory = False),
+ "patch_file": attr.label(),
+ "link_files": attr.string_dict(),
+ "system_link_files": attr.string_dict(),
+ },
+ environ = [
+ "TF_SYSTEM_LIBS",
+ ],
+ implementation = _third_party_http_archive,
+)
diff --git a/third_party/six.BUILD b/third_party/six.BUILD
new file mode 100644
index 00000000..a1b2f7b2
--- /dev/null
+++ b/third_party/six.BUILD
@@ -0,0 +1,14 @@
+# Description:
+# Six provides simple utilities for wrapping over differences between Python 2
+# and Python 3.
+
+licenses(["notice"]) # MIT
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "six",
+ srcs = ["six.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/stblib.BUILD b/third_party/stblib.BUILD
new file mode 100644
index 00000000..f1c361ac
--- /dev/null
+++ b/third_party/stblib.BUILD
@@ -0,0 +1,26 @@
+# Description:
+# Single-file C++ image decoding and encoding libraries
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # MIT license
+
+exports_files(["LICENSE"])
+
+cc_library(
+ name = "stb_image",
+ hdrs = ["stb_image.h"],
+ copts = [
+ "-Wno-unused-function",
+ "$(STACK_FRAME_UNLIMITED)",
+ ],
+ includes = ["."],
+)
+
+cc_library(
+ name = "stb_image_write",
+ hdrs = ["stb_image_write.h"],
+ includes = ["."],
+)
diff --git a/third_party/tensorflow/BUILD b/third_party/tensorflow/BUILD
new file mode 100644
index 00000000..ac039c46
--- /dev/null
+++ b/third_party/tensorflow/BUILD
@@ -0,0 +1 @@
+# Placeholder to make the directory a Bazel package.
diff --git a/third_party/tensorflow/BUILD.tpl b/third_party/tensorflow/BUILD.tpl
new file mode 100644
index 00000000..095021ed
--- /dev/null
+++ b/third_party/tensorflow/BUILD.tpl
@@ -0,0 +1,18 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "tf_header_lib",
+ hdrs = [":tf_header_include"],
+ includes = ["include"],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "libtensorflow_framework",
+ srcs = [":libtensorflow_framework.so"],
+ visibility = ["//visibility:public"],
+)
+
+%{TF_HEADER_GENRULE}
+%{TF_SHARED_LIBRARY_GENRULE}
+
diff --git a/third_party/tensorflow/tf_configure.bzl b/third_party/tensorflow/tf_configure.bzl
new file mode 100644
index 00000000..32826255
--- /dev/null
+++ b/third_party/tensorflow/tf_configure.bzl
@@ -0,0 +1,224 @@
+"""Setup TensorFlow as external dependency"""
+
+_TF_HEADER_DIR = "TF_HEADER_DIR"
+_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"
+_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"
+
+def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
+ if not out:
+ out = tpl
+ repository_ctx.template(
+ out,
+ Label("//third_party/tensorflow:%s.tpl" % tpl),
+ substitutions,
+ )
+
+def _fail(msg):
+ """Output failure message when auto configuration fails."""
+ red = "\033[0;31m"
+ no_color = "\033[0m"
+    fail("%sTensorFlow Configuration Error:%s %s\n" % (red, no_color, msg))
+
+def _is_windows(repository_ctx):
+ """Returns true if the host operating system is windows."""
+ os_name = repository_ctx.os.name.lower()
+ if os_name.find("windows") != -1:
+ return True
+ return False
+
+def _execute(
+ repository_ctx,
+ cmdline,
+ error_msg = None,
+ error_details = None,
+ empty_stdout_fine = False):
+ """Executes an arbitrary shell command.
+
+    Fails with a formatted error message when the command does not succeed.
+
+ Args:
+ repository_ctx: the repository_ctx object.
+ cmdline: list of strings, the command to execute.
+ error_msg: string, a summary of the error if the command fails.
+ error_details: string, details about the error or steps to fix it.
+ empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
+ it's an error.
+
+ Returns:
+ The result of repository_ctx.execute(cmdline).
+ """
+ result = repository_ctx.execute(cmdline)
+ if result.stderr or not (empty_stdout_fine or result.stdout):
+ _fail("\n".join([
+ error_msg.strip() if error_msg else "Repository command failed",
+ result.stderr.strip(),
+ error_details if error_details else "",
+ ]))
+ return result
+
+def _read_dir(repository_ctx, src_dir):
+ """Returns a string with all files in a directory.
+
+ Finds all files inside a directory, traversing subfolders and following
+ symlinks. The returned string contains the full path of all files
+ separated by line breaks.
+
+ Args:
+ repository_ctx: the repository_ctx object.
+ src_dir: directory to find files from.
+
+ Returns:
+ A string of all files inside the given dir.
+ """
+ if _is_windows(repository_ctx):
+ src_dir = src_dir.replace("/", "\\")
+ find_result = _execute(
+ repository_ctx,
+ ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
+ empty_stdout_fine = True,
+ )
+
+ # src_files will be used in genrule.outs where the paths must
+ # use forward slashes.
+ result = find_result.stdout.replace("\\", "/")
+ else:
+ find_result = _execute(
+ repository_ctx,
+ ["find", src_dir, "-follow", "-type", "f"],
+ empty_stdout_fine = True,
+ )
+ result = find_result.stdout
+ return result
+
+def _genrule(genrule_name, command, outs):
+ """Returns a string with a genrule.
+
+    The genrule executes the given command and produces the given outputs.
+
+ Args:
+ genrule_name: A unique name for genrule target.
+ command: The command to run.
+ outs: A list of files generated by this rule.
+
+ Returns:
+ A genrule target.
+ """
+ return (
+ "genrule(\n" +
+ ' name = "' +
+ genrule_name + '",\n' +
+ " outs = [\n" +
+ outs +
+ "\n ],\n" +
+ ' cmd = """\n' +
+ command +
+ '\n """,\n' +
+ ")\n"
+ )
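+# For illustration, a hypothetical call such as
+# _genrule("hdrs", 'cp -f "/tf/a.h" "$(@D)/include/a.h"', '    "include/a.h",')
+# would return roughly the following target text:
+#
+# genrule(
+#     name = "hdrs",
+#     outs = [
+#     "include/a.h",
+#     ],
+#     cmd = """
+# cp -f "/tf/a.h" "$(@D)/include/a.h"
+#     """,
+# )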
+
+def _norm_path(path):
+    """Returns a path with forward slashes and no trailing slash."""
+ path = path.replace("\\", "/")
+ if path[-1] == "/":
+ path = path[:-1]
+ return path
+
+def _symlink_genrule_for_dir(
+ repository_ctx,
+ src_dir,
+ dest_dir,
+ genrule_name,
+ src_files = [],
+ dest_files = [],
+ tf_pip_dir_rename_pair = []):
+    """Returns a genrule to symlink (or copy, on Windows) a set of files.
+
+ If src_dir is passed, files will be read from the given directory; otherwise
+ we assume files are in src_files and dest_files.
+
+ Args:
+ repository_ctx: the repository_ctx object.
+ src_dir: source directory.
+ dest_dir: directory to create symlink in.
+ genrule_name: genrule name.
+ src_files: list of source files instead of src_dir.
+        dest_files: list of corresponding destination files.
+        tf_pip_dir_rename_pair: pair [old, new] of TF pip parent directory names
+            to replace; for example, in the TF pip package the source code is
+            under "tensorflow_core", and we might want to replace it with
+            "tensorflow" to match the header includes.
+
+    Returns:
+ genrule target that creates the symlinks.
+ """
+
+ # Check that tf_pip_dir_rename_pair has the right length
+ tf_pip_dir_rename_pair_len = len(tf_pip_dir_rename_pair)
+ if tf_pip_dir_rename_pair_len != 0 and tf_pip_dir_rename_pair_len != 2:
+ _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % tf_pip_dir_rename_pair_len)
+
+ if src_dir != None:
+ src_dir = _norm_path(src_dir)
+ dest_dir = _norm_path(dest_dir)
+ files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
+
+ # Create a list with the src_dir stripped to use for outputs.
+ if tf_pip_dir_rename_pair_len:
+ dest_files = files.replace(src_dir, "").replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1]).splitlines()
+ else:
+ dest_files = files.replace(src_dir, "").splitlines()
+ src_files = files.splitlines()
+ command = []
+ outs = []
+ for i in range(len(dest_files)):
+ if dest_files[i] != "":
+ # If we have only one file to link we do not want to use the dest_dir, as
+ # $(@D) will include the full path to the file.
+ dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]
+
+ # Copy the headers to create a sandboxable setup.
+ cmd = "cp -f"
+ command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
+ outs.append(' "' + dest_dir + dest_files[i] + '",')
+ genrule = _genrule(
+ genrule_name,
+ " && ".join(command),
+ "\n".join(outs),
+ )
+ return genrule
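+# For example (an illustrative sketch, paths invented): with src_dir =
+# ".../site-packages/tensorflow/include", dest_dir = "include", and
+# tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"], a header found at
+# ".../include/tensorflow_core/core/framework/op.h" is copied by the generated
+# genrule to the output "include/tensorflow/core/framework/op.h".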
+
+def _tf_pip_impl(repository_ctx):
+ tf_header_dir = repository_ctx.os.environ[_TF_HEADER_DIR]
+ tf_header_rule = _symlink_genrule_for_dir(
+ repository_ctx,
+ tf_header_dir,
+ "include",
+ "tf_header_include",
+ tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
+ )
+
+ tf_shared_library_dir = repository_ctx.os.environ[_TF_SHARED_LIBRARY_DIR]
+ tf_shared_library_name = repository_ctx.os.environ[_TF_SHARED_LIBRARY_NAME]
+ tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)
+ tf_shared_library_rule = _symlink_genrule_for_dir(
+ repository_ctx,
+ None,
+ "",
+ "libtensorflow_framework.so",
+ [tf_shared_library_path],
+ ["_pywrap_tensorflow_internal.lib" if _is_windows(repository_ctx) else "libtensorflow_framework.so"],
+ )
+
+ _tpl(repository_ctx, "BUILD", {
+ "%{TF_HEADER_GENRULE}": tf_header_rule,
+ "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
+ })
+
+tf_configure = repository_rule(
+ implementation = _tf_pip_impl,
+ environ = [
+ _TF_HEADER_DIR,
+ _TF_SHARED_LIBRARY_DIR,
+ _TF_SHARED_LIBRARY_NAME,
+ ],
+)
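+# A hedged usage sketch (repo name and paths are illustrative): in WORKSPACE,
+#
+#   load("//third_party/tensorflow:tf_configure.bzl", "tf_configure")
+#   tf_configure(name = "local_config_tf")
+#
+# with the three environment variables exported beforehand, e.g. from an
+# installed TensorFlow pip package:
+#
+#   export TF_HEADER_DIR=$(python -c "import tensorflow as tf; print(tf.sysconfig.get_include())")
+#   export TF_SHARED_LIBRARY_DIR=$(python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())")
+#   export TF_SHARED_LIBRARY_NAME=libtensorflow_framework.so.2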
diff --git a/third_party/tensorflow_lite_ios_build.patch b/third_party/tensorflow_lite_ios_build.patch
new file mode 100644
index 00000000..786e46bc
--- /dev/null
+++ b/third_party/tensorflow_lite_ios_build.patch
@@ -0,0 +1,40 @@
+diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD
+similarity index 97%
+rename from tensorflow/lite/experimental/ios/BUILD.apple
+rename to tensorflow/lite/experimental/ios/BUILD
+index cce0c4df..49eba35f 100644
+--- a/tensorflow/lite/experimental/ios/BUILD.apple
++++ b/tensorflow/lite/experimental/ios/BUILD
+@@ -22,8 +22,7 @@ sh_binary(
+ "hide_symbols_with_allowlist.sh",
+ ],
+ visibility = [
+- "//tensorflow/lite:__subpackages__",
+- "//tensorflow_lite_support:__subpackages__",
++ "//visibility:public",
+ ],
+ )
+
+diff --git a/tensorflow/lite/experimental/ios/ios.bzl b/tensorflow/lite/experimental/ios/ios.bzl
+index 63747eb8..07bcb49d 100644
+--- a/tensorflow/lite/experimental/ios/ios.bzl
++++ b/tensorflow/lite/experimental/ios/ios.bzl
+@@ -60,7 +60,7 @@ def tflite_ios_static_framework(
+ "BUNDLE_NAME=\"" + bundle_name + "\" " +
+ "ALLOWLIST_FILE_PATH=\"$(location " + allowlist_symbols_file + ")\" " +
+ "OUTPUT=\"$(OUTS)\" " +
+- "\"$(location //tensorflow/lite/experimental/ios:hide_symbols_with_allowlist)\"")
++ "\"$(location @org_tensorflow//tensorflow/lite/experimental/ios:hide_symbols_with_allowlist)\"")
+
+ native.genrule(
+ name = name,
+@@ -68,7 +68,7 @@ def tflite_ios_static_framework(
+ outs = [name + ".zip"],
+ cmd = cmd,
+ tools = [
+- "//tensorflow/lite/experimental/ios:hide_symbols_with_allowlist",
++ "@org_tensorflow//tensorflow/lite/experimental/ios:hide_symbols_with_allowlist",
+ ],
+ )
+
+
diff --git a/third_party/tensorflow_text_remove_tf_deps.patch b/third_party/tensorflow_text_remove_tf_deps.patch
new file mode 100644
index 00000000..f7b86f9f
--- /dev/null
+++ b/third_party/tensorflow_text_remove_tf_deps.patch
@@ -0,0 +1,32 @@
+diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD
+index bdca365..1c20eae 100644
+--- a/tensorflow_text/core/kernels/BUILD
++++ b/tensorflow_text/core/kernels/BUILD
+@@ -209,8 +209,12 @@ cc_library(
+ name = "regex_split",
+ srcs = ["regex_split.cc"],
+ hdrs = ["regex_split.h"],
+- deps = OSS_DEPS + [
++ deps = [
+ # absl/strings dep
++ "@com_google_absl//absl/container:inlined_vector",
++ "@com_google_absl//absl/strings",
++ "@com_google_absl//absl/types:optional",
++ "@com_google_absl//absl/types:span",
+ "@com_google_re2//:re2",
+ ],
+ )
+@@ -437,8 +441,12 @@ cc_library(
+ name = "wordpiece_tokenizer",
+ srcs = ["wordpiece_tokenizer.cc"],
+ hdrs = ["wordpiece_tokenizer.h"],
+- deps = OSS_DEPS + [
++ deps = [
+ # absl/strings dep
++ "@com_google_absl//absl/container:inlined_vector",
++ "@com_google_absl//absl/strings",
++ "@com_google_absl//absl/types:optional",
++ "@com_google_absl//absl/types:span",
+ "@icu//:common",
+ ],
+ ) \ No newline at end of file
diff --git a/third_party/toolchains/java/BUILD b/third_party/toolchains/java/BUILD
new file mode 100644
index 00000000..83722915
--- /dev/null
+++ b/third_party/toolchains/java/BUILD
@@ -0,0 +1,18 @@
+# Workaround for https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1.
+# TensorFlow still targets Java 1.7 (see JAVACOPTS in tensorflow/java/build_defs.bzl),
+# which doesn't support the "-parameters" flag. Starting from Java 11 (the default
+# since Bazel 0.29.1), a warning is emitted if "-parameters" is passed; if "-Werror"
+# is also set, the compile action fails. To work around this, we override the misc
+# value of the default java toolchain to drop the "-parameters" flag.
+load("@bazel_tools//tools/jdk:default_java_toolchain.bzl", "default_java_toolchain")
+
+licenses(["notice"])
+
+default_java_toolchain(
+ name = "tf_java_toolchain",
+ misc = [
+ "-XDskipDuplicateBridges=true",
+ "-g",
+ ],
+ visibility = ["//visibility:public"],
+)
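+# A hedged usage sketch: the toolchain is typically selected via build flags,
+# e.g. in .bazelrc (flag names as used by Bazel of this era):
+#
+#   build --java_toolchain=//third_party/toolchains/java:tf_java_toolchain
+#   build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain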
diff --git a/third_party/utf.BUILD b/third_party/utf.BUILD
new file mode 100644
index 00000000..0a21fd78
--- /dev/null
+++ b/third_party/utf.BUILD
@@ -0,0 +1,38 @@
+cc_library(
+ name = "utf",
+ srcs = [
+ "libutf/rune.c",
+ "libutf/runestrcat.c",
+ "libutf/runestrchr.c",
+ "libutf/runestrcmp.c",
+ "libutf/runestrcpy.c",
+ "libutf/runestrdup.c",
+ "libutf/runestrecpy.c",
+ "libutf/runestrlen.c",
+ "libutf/runestrncat.c",
+ "libutf/runestrncmp.c",
+ "libutf/runestrncpy.c",
+ "libutf/runestrrchr.c",
+ "libutf/runestrstr.c",
+ "libutf/runetype.c",
+ "libutf/utfecpy.c",
+ "libutf/utflen.c",
+ "libutf/utfnlen.c",
+ "libutf/utfrrune.c",
+ "libutf/utfrune.c",
+ "libutf/utfutf.c",
+ ],
+ hdrs = [
+ "libutf/plan9.h",
+ "libutf/utf.h",
+ "libutf/utfdef.h",
+ ],
+ copts = [
+ "-Wno-parentheses",
+ ],
+ includes = [
+ ".",
+ "libutf",
+ ],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD
new file mode 100644
index 00000000..275782e0
--- /dev/null
+++ b/third_party/zlib.BUILD
@@ -0,0 +1,39 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])
+
+cc_library(
+ name = "zlib",
+ srcs = [
+ "adler32.c",
+ "compress.c",
+ "crc32.c",
+ "crc32.h",
+ "deflate.c",
+ "deflate.h",
+ "gzclose.c",
+ "gzguts.h",
+ "gzlib.c",
+ "gzread.c",
+ "gzwrite.c",
+ "infback.c",
+ "inffast.c",
+ "inffast.h",
+ "inffixed.h",
+ "inflate.c",
+ "inflate.h",
+ "inftrees.c",
+ "inftrees.h",
+ "trees.c",
+ "trees.h",
+ "uncompr.c",
+ "zutil.c",
+ "zutil.h",
+ ],
+ hdrs = [
+ "zconf.h",
+ "zlib.h",
+ ],
+ copts = ["-Wno-implicit-function-declaration"],
+ includes = ["."],
+)