diff options
author | Tony Mak <tonymak@google.com> | 2021-02-17 12:01:24 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2021-02-17 12:01:24 +0000 |
commit | 699e7bd5f0f6f12005fc96592c5b82e692dcfc56 (patch) | |
tree | dbbe4fa5f2c1115321652b1376792217f94739cb | |
parent | ed73f2f75b75da3675b0a984dbfd195702855ed1 (diff) | |
parent | 815025dfce13a988896874b2f182de7c0e6cbad1 (diff) | |
download | tflite-support-699e7bd5f0f6f12005fc96592c5b82e692dcfc56.tar.gz |
Import platform/external/tflite-support am: ff3e0f1735 am: 815025dfce
Original change: https://android-review.googlesource.com/c/platform/external/tflite-support/+/1590212
MUST ONLY BE SUBMITTED BY AUTOMERGER
Change-Id: I8b292b6b90a02c93617557e2822e8251c36affec
412 files changed, 43507 insertions, 0 deletions
diff --git a/.bazelrc b/.bazelrc new file mode 100644 index 00000000..bdc2c07d --- /dev/null +++ b/.bazelrc @@ -0,0 +1,170 @@ +# This file is based on tensorflow's (v2.2.0) .bazelrc found here: +# https://github.com/tensorflow/tensorflow/blob/v2.2.0/.bazelrc + +# Sets the default Apple platform to macOS. + +build --apple_platform_type=macos + +# Enable using platform specific build settings +build --enable_platform_specific_config + +# Flag to enable remote config. Required starting from TF 2.2. +common --experimental_repo_remote_exec + +# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1 +build --java_toolchain=//third_party/toolchains/java:tf_java_toolchain +build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain + +# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs. +build:android --copt=-w +build:linux --copt=-w +build:macos --copt=-w +build:windows --copt=/w + +# Android workspace configurations. Should be replaced by an interactive configure in the future. +build --action_env ANDROID_NDK_HOME +build --action_env ANDROID_NDK_API_LEVEL +build --action_env ANDROID_BUILD_TOOLS_VERSION +build --action_env ANDROID_SDK_API_LEVEL +build --action_env ANDROID_SDK_HOME + +# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the +# target CPU to build transitive dependencies correctly. 
See +# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu + +build:android --crosstool_top=//external:android/crosstool +build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain +build:android_arm --config=android +build:android_arm --cpu=armeabi-v7a +build:android_arm --fat_apk_cpu=armeabi-v7a +build:android_arm64 --config=android +build:android_arm64 --cpu=arm64-v8a +build:android_arm64 --fat_apk_cpu=arm64-v8a +build:android_x86 --config=android +build:android_x86 --cpu=x86 +build:android_x86 --fat_apk_cpu=x86 +build:android_x86_64 --config=android +build:android_x86_64 --cpu=x86_64 +build:android_x86_64 --fat_apk_cpu=x86_64 + +# iOS configs for each architecture and the fat binary builds. +build:ios --apple_platform_type=ios +build:ios --apple_bitcode=embedded --copt=-fembed-bitcode +build:ios --copt=-Wno-c++11-narrowing +build:ios_armv7 --config=ios +build:ios_armv7 --cpu=ios_armv7 +build:ios_arm64 --config=ios +build:ios_arm64 --cpu=ios_arm64 +build:ios_x86_64 --config=ios +build:ios_x86_64 --cpu=ios_x86_64 +build:ios_fat --config=ios +build:ios_fat --ios_multi_cpus=armv7,arm64,x86_64 + +# By default, build TF in C++ 14 mode. +build:android --cxxopt=-std=c++14 +build:android --host_cxxopt=-std=c++14 +build:ios --cxxopt=-std=c++14 +build:ios --host_cxxopt=-std=c++14 +build:linux --cxxopt=-std=c++14 +build:linux --host_cxxopt=-std=c++14 +build:macos --cxxopt=-std=c++14 +build:macos --host_cxxopt=-std=c++14 +build:windows --cxxopt=/std:c++14 +build:windows --host_cxxopt=/std:c++14 + +# Config to use a mostly-static build and disable modular op registration +# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python). +# By default, TensorFlow will build with a dependence on +# //tensorflow:libtensorflow_framework.so. 
+build:monolithic --define framework_shared_object=false + +# For projects which use TensorFlow as part of a Bazel build process, putting +# nothing in a bazelrc will default to a monolithic build. The following line +# opts in to modular op registration support by default. +build --define framework_shared_object=true + +# ASAN build +build:asan --strip=never +build:asan --copt -fsanitize=address +build:asan --copt -DADDRESS_SANITIZER +build:asan --copt -O1 +build:asan --copt -g +build:asan --copt -fno-omit-frame-pointer +build:asan --linkopt -fsanitize=address + +# Flags for open source build, always set to be true. +build --define open_source_build=true +test --define open_source_build=true + +# dbg config, as a shorthand for '--config=opt -c dbg' +build:dbg --config=opt -c dbg +# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360 +build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON +# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498 +build:dbg --copt -DDEBUG_BUILD + +build --define=use_fast_cpp_protos=true +build --define=allow_oversize_protos=true + +build --spawn_strategy=standalone +build -c opt + +# Adding "--cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" creates parity with TF +# compilation options. It also addresses memory use due to +# copy-on-write semantics of std::strings of the older ABI. +build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0 + +# Make Bazel print out all options from rc files. +build --announce_rc + +# Other build flags. +build --define=grpc_no_ares=true + +# See https://github.com/bazelbuild/bazel/issues/7362 for information on what +# --incompatible_remove_legacy_whole_archive flag does. +# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate +# Tensorflow to the default, however test coverage wasn't enough to catch the +# errors. +# There is ongoing work on Bazel team's side to provide support for transitive +# shared libraries. 
As part of migrating to transitive shared libraries, we +# hope to provide a better mechanism for control over symbol exporting, and +# then tackle this issue again. +# +# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library +# archives in -whole_archive -no_whole_archive. +build --noincompatible_remove_legacy_whole_archive + +# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0 +# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC: +# https://github.com/tensorflow/community/pull/179 +build --noincompatible_prohibit_aapt1 + +# Build TF with C++ 17 features. +build:c++17 --cxxopt=-std=c++1z +build:c++17 --cxxopt=-stdlib=libc++ +build:c++1z --config=c++17 + +# Enable using platform specific build settings, except when cross-compiling for +# mobile platforms. +build --enable_platform_specific_config +build:android --noenable_platform_specific_config +build:ios --noenable_platform_specific_config + +# Suppress all warning messages. +build:short_logs --output_filter=DONT_MATCH_ANYTHING +build:verbose_logs --output_filter= +build --config=short_logs + +# Options to build TensorFlow 1.x or 2.x. +build:v1 --define=tf_api_version=1 +build:v2 --define=tf_api_version=2 +build:v1 --action_env=TF2_BEHAVIOR=0 +build:v2 --action_env=TF2_BEHAVIOR=1 +build --config=v2 +test --config=v2 + +# Options from ./configure +try-import %workspace%/.tf_configure.bazelrc + +# Put user-specific options in .bazelrc.user +try-import %workspace%/.bazelrc.user @@ -0,0 +1,5 @@ +exports_files( + [ + "LICENSE", + ], +) diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..786bd073 --- /dev/null +++ b/LICENSE @@ -0,0 +1,203 @@ +Copyright 2020 The TensorFlow Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/METADATA b/METADATA new file mode 100644 index 00000000..fb2db712 --- /dev/null +++ b/METADATA @@ -0,0 +1,19 @@ +name: "tflite-support" +description: + "TFLite Support is a toolkit that helps users to develop ML and deploy " + "TFLite models onto mobile devices. It works cross-Platform and is " + "supported on Java, C++ (WIP), and Swift (WIP)." + +third_party { + url { + type: HOMEPAGE + value: "https://github.com/tensorflow/tflite-support" + } + url { + type: GIT + value: "https://github.com/tensorflow/tflite-support" + } + version: "v0.1.0" + last_upgrade_date { year: 2021 month: 1 day: 14 } + license_type: NOTICE +}
\ No newline at end of file diff --git a/MODULE_LICENSE_APACHE2 b/MODULE_LICENSE_APACHE2 new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/MODULE_LICENSE_APACHE2 diff --git a/README.md b/README.md new file mode 100644 index 00000000..67d4a8fd --- /dev/null +++ b/README.md @@ -0,0 +1,62 @@ +# TensorFlow Lite Support + +TFLite Support is a toolkit that helps users to develop ML and deploy TFLite +models onto mobile devices. It works cross-Platform and is supported on Java, +C++ (WIP), and Swift (WIP). The TFLite Support project consists of the following +major components: + +* **TFLite Support Library**: a cross-platform library that helps to + deploy TFLite models onto mobile devices. +* **TFLite Model Metadata**: (metadata populator and metadata extractor + library): includes both human and machine readable information about what a + model does and how to use the model. +* **TFLite Support Codegen Tool**: an executable that generates model wrapper + automatically based on the Support Library and the metadata. +* **TFLite Support Task Library**: a flexible and ready-to-use library for + common machine learning model types, such as classification and detection, + client can also build their own native/Android/iOS inference API on Task + Library infra. + +TFLite Support library serves different tiers of deployment requirements from +easy onboarding to fully customizable. There are three major use cases that +TFLite Support targets at: + +* **Provide ready-to-use APIs for users to interact with the model**. \ + This is achieved by the TFLite Support Codegen tool, where users can get the + model interface (contains ready-to-use APIs) simply by passing the model to + the codegen tool. The automatic codegen strategy is designed based on the + TFLite metadata. + +* **Provide optimized model interface for popular ML tasks**. 
\ + The model interfaces provided by the TFLite Support Task Library are + specifically optimized compared to the codegen version in terms of both + usability and performance. Users can also swap their own custom models with + the default models in each task. + +* **Provide the flexibility to customize model interface and build inference + pipelines**. \ + The TFLite Support Util Library contains varieties of util methods and data + structures to perform pre/post processing and data conversion. It is also + designed to match the behavior of TensorFlow modules, such as TF.Image and + TF.text, ensuring consistency from training to inferencing. + +See the +[documentation on tensorflow.org](https://www.tensorflow.org/lite/inference_with_metadata/overview) +for more instruction and examples. + +## Build Instructions + +We use Bazel to build the project. When you're building the Java (Android) +Utils, you need to set up following env variables correctly: + +* `ANDROID_NDK_HOME` +* `ANDROID_SDK_HOME` +* `ANDROID_NDK_API_LEVEL` +* `ANDROID_SDK_API_LEVEL` +* `ANDROID_BUILD_TOOLS_VERSION` + +## Contact us + +Let us know what you think about TFLite Support by creating a +[new Github issue](https://github.com/tensorflow/tflite-support/issues/new), or +email us at tflite-support-team@google.com. 
diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 00000000..21948710 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,384 @@ +workspace(name = "org_tensorflow_lite_support") + +load("@bazel_tools//tools/build_defs/repo:java.bzl", "java_import_external") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@//third_party/py:python_configure.bzl", "python_configure") + +http_archive( + name = "io_bazel_rules_closure", + sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9", + strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13 + ], +) + +# Apple and Swift rules. +# https://github.com/bazelbuild/rules_apple/releases +http_archive( + name = "build_bazel_rules_apple", + sha256 = "ee9e6073aeb5a65c100cb9c44b0017c937706a4ae03176e14a7e78620a198079", + strip_prefix = "rules_apple-5131f3d46794bf227d296c82f30c2499c9de3c5b", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_apple/archive/5131f3d46794bf227d296c82f30c2499c9de3c5b.tar.gz", + "https://github.com/bazelbuild/rules_apple/archive/5131f3d46794bf227d296c82f30c2499c9de3c5b.tar.gz", + ], +) + +# https://github.com/bazelbuild/rules_swift/releases +http_archive( + name = "build_bazel_rules_swift", + sha256 = "d0833bc6dad817a367936a5f902a0c11318160b5e80a20ece35fb85a5675c886", + strip_prefix = "rules_swift-3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_swift/archive/3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8.tar.gz", + "https://github.com/bazelbuild/rules_swift/archive/3eeeb53cebda55b349d64c9fc144e18c5f7c0eb8.tar.gz", + ], +) + +# tf-nightly-20200810 
+http_archive( + name = "org_tensorflow", + sha256 = "26c833b7e1873936379e810a39d14700281125257ddda8cd822c89111db6f6ae", + strip_prefix = "tensorflow-2.4.0", + urls = [ + "https://github.com/tensorflow/tensorflow/archive/v2.4.0.tar.gz", + ], + patches = ["@//third_party:tensorflow_lite_ios_build.patch"], + patch_args = ["-p1"], +) + +# Set up dependencies. Need to do this before set up TF so that our modification +# could take effects. +load("//third_party:repo.bzl", "third_party_http_archive") + +# Use our patched gflags which fixes a linking issue. +load("//third_party/gflags:workspace.bzl", gflags = "repo") +gflags() + +third_party_http_archive( + name = "pybind11", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.6.0.tar.gz", + "https://github.com/pybind/pybind11/archive/v2.6.0.tar.gz", + ], + sha256 = "90b705137b69ee3b5fc655eaca66d0dc9862ea1759226f7ccd3098425ae69571", + strip_prefix = "pybind11-2.6.0", + build_file = "//third_party:pybind11.BUILD", +) + +http_archive( + name = "absl_py", + sha256 = "603febc9b95a8f2979a7bdb77d2f5e4d9b30d4e0d59579f88eba67d4e4cc5462", + strip_prefix = "abseil-py-pypi-v0.9.0", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/abseil/abseil-py/archive/pypi-v0.9.0.tar.gz", + "https://github.com/abseil/abseil-py/archive/pypi-v0.9.0.tar.gz", + ], +) + +http_archive( + name = "six_archive", + build_file = "//third_party:six.BUILD", + sha256 = "d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73", + strip_prefix = "six-1.12.0", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz", + "https://pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz", + ], +) + +http_archive( + name = "com_google_sentencepiece", + strip_prefix = "sentencepiece-1.0.0", + sha256 = "c05901f30a1d0ed64cbcf40eba08e48894e1b0e985777217b7c9036cac631346", + urls = [ + 
"https://github.com/google/sentencepiece/archive/1.0.0.zip", + ], +) + +http_archive( + name = "org_tensorflow_text", + sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8", + strip_prefix = "text-2.2.0", + urls = [ + "https://github.com/tensorflow/text/archive/v2.2.0.zip", + ], + patches = ["@//third_party:tensorflow_text_remove_tf_deps.patch"], + patch_args = ["-p1"], + repo_mapping = {"@com_google_re2": "@com_googlesource_code_re2"}, +) + +http_archive( + name = "com_googlesource_code_re2", + sha256 = "d070e2ffc5476c496a6a872a6f246bfddce8e7797d6ba605a7c8d72866743bf9", + strip_prefix = "re2-506cfa4bffd060c06ec338ce50ea3468daa6c814", + urls = [ + "https://github.com/google/re2/archive/506cfa4bffd060c06ec338ce50ea3468daa6c814.tar.gz", + ], +) + +# ABSL cpp library lts_2020_02_25 +# Needed for absl/status +http_archive( + name = "com_google_absl", + build_file = "//third_party:com_google_absl.BUILD", + urls = [ + "https://github.com/abseil/abseil-cpp/archive/20200225.tar.gz", + ], + # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. 
+ patches = [ + "@//third_party:com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff" + ], + patch_args = [ + "-p1", + ], + strip_prefix = "abseil-cpp-20200225", + sha256 = "728a813291bdec2aa46eab8356ace9f75ac2ed9dfe2df5ab603c4e6c09f1c353" +) + +http_archive( + name = "com_google_glog", + sha256 = "1ee310e5d0a19b9d584a855000434bb724aa744745d5b8ab1855c85bff8a8e21", + strip_prefix = "glog-028d37889a1e80e8a07da1b8945ac706259e5fd8", + urls = [ + "https://mirror.bazel.build/github.com/google/glog/archive/028d37889a1e80e8a07da1b8945ac706259e5fd8.tar.gz", + "https://github.com/google/glog/archive/028d37889a1e80e8a07da1b8945ac706259e5fd8.tar.gz", + ], +) + + +http_archive( + name = "zlib", + build_file = "//third_party:zlib.BUILD", + sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", + strip_prefix = "zlib-1.2.11", + urls = [ + "http://mirror.bazel.build/zlib.net/fossils/zlib-1.2.11.tar.gz", + "http://zlib.net/fossils/zlib-1.2.11.tar.gz", # 2017-01-15 + ], +) + +http_archive( + name = "org_libzip", + build_file = "//third_party:libzip.BUILD", + sha256 = "a5d22f0c87a2625450eaa5e10db18b8ee4ef17042102d04c62e311993a2ba363", + strip_prefix = "libzip-rel-1-5-1", + urls = [ + # Bazel does not like the official download link at libzip.org, + # so use the GitHub release tag. + "https://mirror.bazel.build/github.com/nih-at/libzip/archive/rel-1-5-1.zip", + "https://github.com/nih-at/libzip/archive/rel-1-5-1.zip", + ], +) + +http_archive( + name = "libyuv", + urls = ["https://chromium.googlesource.com/libyuv/libyuv/+archive/6d603ec3f57dafddc424ef895e5d903915e94ba6.tar.gz"], + # Adding the constrain of sha256 and strip_prefix will cause failure. + # It seems that the downloaded libyuv was different every time, so that + # the specified sha256 and strip_prefix cannot match. 
+ # sha256 = "ce196c72858456baa8022fa4a0dc18b77d619265dbc0e3d58e25ad15ca402522", + # strip_prefix = "libyuv-6d603ec3f57dafddc424ef895e5d903915e94ba6", + build_file = "//third_party:libyuv.BUILD", +) + +http_archive( + name = "stblib", + strip_prefix = "stb-b42009b3b9d4ca35bc703f5310eedc74f584be58", + sha256 = "13a99ad430e930907f5611325ec384168a958bf7610e63e60e2fd8e7b7379610", + urls = ["https://github.com/nothings/stb/archive/b42009b3b9d4ca35bc703f5310eedc74f584be58.tar.gz"], + build_file = "//third_party:stblib.BUILD", +) + +http_archive( + name = "google_toolbox_for_mac", + url = "https://github.com/google/google-toolbox-for-mac/archive/v2.2.1.zip", + sha256 = "e3ac053813c989a88703556df4dc4466e424e30d32108433ed6beaec76ba4fdc", + strip_prefix = "google-toolbox-for-mac-2.2.1", + build_file = "@//third_party:google_toolbox_for_mac.BUILD", +) + +http_archive( + name = "utf_archive", + build_file = "@//third_party:utf.BUILD", + sha256 = "262a902f622dcd28e05b8a4be10da0aa3899050d0be8f4a71780eed6b2ea65ca", + urls = [ + "https://mirror.bazel.build/9fans.github.io/plan9port/unix/libutf.tgz", + "https://9fans.github.io/plan9port/unix/libutf.tgz", + ], +) + +http_archive( + name = "icu", + strip_prefix = "icu-release-64-2", + sha256 = "dfc62618aa4bd3ca14a3df548cd65fe393155edd213e49c39f3a30ccd618fc27", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/unicode-org/icu/archive/release-64-2.zip", + "https://github.com/unicode-org/icu/archive/release-64-2.zip", + ], + build_file = "@//third_party:icu.BUILD", +) + +http_archive( + name = "fft2d", + build_file = "@//third_party/fft2d:fft2d.BUILD", + sha256 = "5f4dabc2ae21e1f537425d58a49cdca1c49ea11db0d6271e2a4b27e9697548eb", + strip_prefix = "OouraFFT-1.0", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/petewarden/OouraFFT/archive/v1.0.tar.gz", + "https://github.com/petewarden/OouraFFT/archive/v1.0.tar.gz", + ], +) + +http_archive( + name = "darts_clone", + build_file = 
"@//third_party:darts_clone.BUILD", + sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c", + strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983", + urls = [ + "https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip", + ], +) + +http_archive( + name = "com_google_protobuf", + sha256 = "a79d19dcdf9139fa4b81206e318e33d245c4c9da1ffed21c87288ed4380426f9", + strip_prefix = "protobuf-3.11.4", + urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.11.4.tar.gz"], + patches = [ + "@//third_party:com_google_protobuf_fixes.diff" + ], + patch_args = [ + "-p1", + ], +) + +# AutoValue 1.6+ shades Guava, Auto Common, and JavaPoet. That's OK +# because none of these jars become runtime dependencies. +java_import_external( + name = "com_google_auto_value", + jar_sha256 = "fd811b92bb59ae8a4cf7eb9dedd208300f4ea2b6275d726e4df52d8334aaae9d", + jar_urls = [ + "https://mirror.bazel.build/repo1.maven.org/maven2/com/google/auto/value/auto-value/1.6/auto-value-1.6.jar", + "https://repo1.maven.org/maven2/com/google/auto/value/auto-value/1.6/auto-value-1.6.jar", + ], + licenses = ["notice"], # Apache 2.0 + generated_rule_name = "processor", + exports = ["@com_google_auto_value_annotations"], + extra_build_file_content = "\n".join([ + "java_plugin(", + " name = \"AutoAnnotationProcessor\",", + " output_licenses = [\"unencumbered\"],", + " processor_class = \"com.google.auto.value.processor.AutoAnnotationProcessor\",", + " tags = [\"annotation=com.google.auto.value.AutoAnnotation;genclass=${package}.AutoAnnotation_${outerclasses}${classname}_${methodname}\"],", + " deps = [\":processor\"],", + ")", + "", + "java_plugin(", + " name = \"AutoOneOfProcessor\",", + " output_licenses = [\"unencumbered\"],", + " processor_class = \"com.google.auto.value.processor.AutoOneOfProcessor\",", + " tags = [\"annotation=com.google.auto.value.AutoValue;genclass=${package}.AutoOneOf_${outerclasses}${classname}\"],", + " deps 
= [\":processor\"],", + ")", + "", + "java_plugin(", + " name = \"AutoValueProcessor\",", + " output_licenses = [\"unencumbered\"],", + " processor_class = \"com.google.auto.value.processor.AutoValueProcessor\",", + " tags = [\"annotation=com.google.auto.value.AutoValue;genclass=${package}.AutoValue_${outerclasses}${classname}\"],", + " deps = [\":processor\"],", + ")", + "", + "java_library(", + " name = \"com_google_auto_value\",", + " exported_plugins = [", + " \":AutoAnnotationProcessor\",", + " \":AutoOneOfProcessor\",", + " \":AutoValueProcessor\",", + " ],", + " exports = [\"@com_google_auto_value_annotations\"],", + ")", + ]), +) + +# Auto value annotations +java_import_external( + name = "com_google_auto_value_annotations", + jar_sha256 = "d095936c432f2afc671beaab67433e7cef50bba4a861b77b9c46561b801fae69", + jar_urls = [ + "https://mirror.bazel.build/repo1.maven.org/maven2/com/google/auto/value/auto-value-annotations/1.6/auto-value-annotations-1.6.jar", + "https://repo1.maven.org/maven2/com/google/auto/value/auto-value-annotations/1.6/auto-value-annotations-1.6.jar", + ], + licenses = ["notice"], # Apache 2.0 + neverlink = True, + default_visibility = ["@com_google_auto_value//:__pkg__"], +) + +load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") + +flatbuffers() +# Set up TF. +load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace") +tf_workspace(tf_repo_name="@org_tensorflow") + +load("//third_party/tensorflow:tf_configure.bzl", "tf_configure") +tf_configure(name = "local_config_tf") + +# TF submodule compilation doesn't take care of grpc deps. Do it manually here. 
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +grpc_deps() + +load( + "@build_bazel_rules_apple//apple:repositories.bzl", + "apple_rules_dependencies", +) +apple_rules_dependencies() + +load( + "@build_bazel_apple_support//lib:repositories.bzl", + "apple_support_dependencies", +) +apple_support_dependencies() + +load("@upb//bazel:repository_defs.bzl", "bazel_version_repository") +bazel_version_repository(name = "bazel_version") + + +# Set up Android. +load("//third_party/android:android_configure.bzl", "android_configure") +android_configure(name="local_config_android") +load("@local_config_android//:android.bzl", "android_workspace") +android_workspace() + +python_configure(name = "local_config_python") + + +# Maven dependencies. + +RULES_JVM_EXTERNAL_TAG = "3.2" + +http_archive( + name = "rules_jvm_external", + strip_prefix = "rules_jvm_external-%s" % RULES_JVM_EXTERNAL_TAG, + sha256 = "82262ff4223c5fda6fb7ff8bd63db8131b51b413d26eb49e3131037e79e324af", + url = "https://github.com/bazelbuild/rules_jvm_external/archive/%s.zip" % RULES_JVM_EXTERNAL_TAG, +) + +load("@rules_jvm_external//:defs.bzl", "maven_install") + +maven_install( + artifacts = [ + "androidx.annotation:annotation:aar:1.1.0", + ], + repositories = [ + "https://jcenter.bintray.com", + "https://maven.google.com", + "https://dl.google.com/dl/android/maven2", + "https://repo1.maven.org/maven2", + ], + fetch_sources = True, + version_conflict_policy = "pinned", +) diff --git a/tensorflow_lite_support/BUILD b/tensorflow_lite_support/BUILD new file mode 100644 index 00000000..f123f1f2 --- /dev/null +++ b/tensorflow_lite_support/BUILD @@ -0,0 +1,38 @@ +# TFLite Support is a toolkit that helps users to develop ML and deploy TFLite +# models onto mobile devices. + +package( + default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["LICENSE"]) + +# LINT.IfChange +package_group( + name = "users", + packages = [ + # tensorflow_examples/... 
dep, + "//tensorflow_lite_support/...", + "//third_party/tensorflow_models/...", + ], +) +# Remove internal path from tensorflow_lite_support:users in the copybara file. +# LINT.ThenChange(//tensorflow_lite_support/copy.bara.sky) + +# Config setting for determining if we are building for Android. +config_setting( + name = "android", + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +# Config setting for determining if we are building for macos. +config_setting( + name = "macos", + values = { + "apple_platform_type": "macos", + "cpu": "darwin", + }, + visibility = ["//visibility:public"], +) diff --git a/tensorflow_lite_support/cc/BUILD b/tensorflow_lite_support/cc/BUILD new file mode 100644 index 00000000..b19bfdec --- /dev/null +++ b/tensorflow_lite_support/cc/BUILD @@ -0,0 +1,25 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "common", + srcs = [ + "common.cc", + ], + hdrs = ["common.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +config_setting( + name = "tflite_use_c_api", + values = { + "copt": "-DTFLITE_USE_C_API", + }, + visibility = ["//tensorflow_lite_support:__subpackages__"], +) diff --git a/tensorflow_lite_support/cc/common.cc b/tensorflow_lite_support/cc/common.cc new file mode 100644 index 00000000..47dd3bcc --- /dev/null +++ b/tensorflow_lite_support/cc/common.cc @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/common.h" + +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" + +namespace tflite { +namespace support { + +absl::Status CreateStatusWithPayload(absl::StatusCode canonical_code, + absl::string_view message, + TfLiteSupportStatus tfls_code) { + // NOTE: Ignores `message` if the canonical code is ok. + absl::Status status = absl::Status(canonical_code, message); + // NOTE: Does nothing if the canonical code is ok. + status.SetPayload(kTfLiteSupportPayload, absl::Cord(absl::StrCat(tfls_code))); + return status; +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/common.h b/tensorflow_lite_support/cc/common.h new file mode 100644 index 00000000..50c8dc40 --- /dev/null +++ b/tensorflow_lite_support/cc/common.h @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace tflite { +namespace support { + +// Name (aka type URL key) of the `absl::Status` payload which contains a +// stringified `TfLiteSupportStatus` code (see below). +constexpr absl::string_view kTfLiteSupportPayload = + "tflite::support::TfLiteSupportStatus"; + +// Error codes for TensorFlow Lite Support (TFLS) C++ APIs. +// +// Such codes capture errors encountered in the TFLS layer. They complement all +// the other type of errors that occur in the lower-level TF Lite codebase (see +// `TfLiteStatus` codes). +// +// At runtime, such codes are meant to be attached (where applicable) to a +// `absl::Status` in a key-value manner with `kTfLiteSupportPayload` as key and +// stringifed error code as value (aka payload). This logic is encapsulated in +// the `CreateStatusWithPayload` helper below for convenience. +// +// The returned status includes: +// 1. The canonical error code (INVALID_ARGUMENT) +// 2. The fine-grained error message ("Invalid metadata ...") +// 3. The specific TFLS code as a payload (kMetadataInvalidSchemaVersionError) +enum class TfLiteSupportStatus { + // Generic error codes. + + // Success. + kOk = 0, + // Unspecified error. + kError = 1, + // Invalid argument specified. + kInvalidArgumentError = 2, + // Invalid FlatBuffer file or buffer specified. + kInvalidFlatBufferError = 3, + + // File I/O error codes. + + // No such file. + kFileNotFoundError = 100, + // Permission issue. + kFilePermissionDeniedError, + // I/O error when reading file. + kFileReadError, + // I/O error when mmap-ing file. + kFileMmapError, + + // TensorFlow Lite metadata error codes. + + // Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer. 
+ kMetadataInvalidSchemaVersionError = 200, + // No such associated file within metadata, or file has not been packed. + kMetadataAssociatedFileNotFoundError, + // ZIP I/O error when unpacking an associated file. + kMetadataAssociatedFileZipError, + // Inconsistency error between the metadata and actual TF Lite model. + // E.g.: number of labels and output tensor values differ. + kMetadataInconsistencyError, + // Invalid process units specified. + // E.g.: multiple ProcessUnits with the same type for a given tensor. + kMetadataInvalidProcessUnitsError, + // Inconsistency error with the number of labels. + // E.g.: label files for different locales have a different number of labels. + kMetadataNumLabelsMismatchError, + // Score calibration parameters parsing error. + // E.g.: too many parameters provided in the corresponding associated file. + kMetadataMalformedScoreCalibrationError, + // Unexpected number of subgraphs for the current task. + // E.g.: image classification expects a single subgraph. + kMetadataInvalidNumSubgraphsError, + // A given tensor requires NormalizationOptions but none were found. + // E.g.: float input tensor requires normalization to preprocess input images. + kMetadataMissingNormalizationOptionsError, + // Invalid ContentProperties specified. + // E.g. expected ImageProperties, got BoundingBoxProperties. + kMetadataInvalidContentPropertiesError, + // Metadata is mandatory but was not found. + // E.g. current task requires TFLite Model Metadata but none was found. + kMetadataNotFoundError, + // Associated TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS file is mandatory but + // none was found or it was empty. + // E.g. current task requires labels but none were found. + kMetadataMissingLabelsError, + // The ProcessingUnit for tokenizer is not correctly configured. + // E.g BertTokenizer doesn't have a valid vocab file associated. + kMetadataInvalidTokenizerError, + + // Input tensor(s) error codes. 
+ + // Unexpected number of input tensors for the current task. + // E.g. current task expects a single input tensor. + kInvalidNumInputTensorsError = 300, + // Unexpected input tensor dimensions for the current task. + // E.g.: only 4D input tensors supported. + kInvalidInputTensorDimensionsError, + // Unexpected input tensor type for the current task. + // E.g.: current task expects a uint8 pixel image as input. + kInvalidInputTensorTypeError, + // Unexpected input tensor bytes size. + // E.g.: size in bytes does not correspond to the expected number of pixels. + kInvalidInputTensorSizeError, + // No correct input tensor found for the model. + // E.g.: input tensor name is not part of the text model's input tensors. + kInputTensorNotFoundError, + + // Output tensor(s) error codes. + + // Unexpected output tensor dimensions for the current task. + // E.g.: only a batch size of 1 is supported. + kInvalidOutputTensorDimensionsError = 400, + // Unexpected input tensor type for the current task. + // E.g.: multi-head model with different output tensor types. + kInvalidOutputTensorTypeError, + // No correct output tensor found for the model. + // E.g.: output tensor name is not part of the text model's output tensors. + kOutputTensorNotFoundError, + // Unexpected number of output tensors for the current task. + // E.g.: current task expects a single output tensor. + kInvalidNumOutputTensorsError, + + // Image processing error codes. + + // Unspecified image processing failures. + kImageProcessingError = 500, + // Unexpected input or output buffer metadata. + // E.g.: rotate RGBA buffer to Grayscale buffer by 90 degrees. + kImageProcessingInvalidArgumentError, + // Image processing operation failures. + // E.g. libyuv rotation failed for an unknown reason. + kImageProcessingBackendError, +}; + +// Convenience helper to create an `absl::Status` augmented with the +// fine-grained `tfls_code` attached as payload under the +// `kTfLiteSupportPayload` type URL key. 
+// +// This should only be used for non-ok codes since otherwise it does nothing +// more than returning an object identical to an OK status. See `absl::Status` +// for more details. +absl::Status CreateStatusWithPayload( + absl::StatusCode canonical_code, absl::string_view message, + tflite::support::TfLiteSupportStatus tfls_code = + tflite::support::TfLiteSupportStatus::kError); + +} // namespace support +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_COMMON_H_ diff --git a/tensorflow_lite_support/cc/port/BUILD b/tensorflow_lite_support/cc/port/BUILD new file mode 100644 index 00000000..195d5a11 --- /dev/null +++ b/tensorflow_lite_support/cc/port/BUILD @@ -0,0 +1,73 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "statusor", + hdrs = [ + "statusor.h", + ], + deps = [ + "//tensorflow_lite_support/cc/port/default:statusor", + ], +) + +cc_library( + name = "status_macros", + hdrs = [ + "status_macros.h", + ], + deps = [ + "//tensorflow_lite_support/cc/port/default:status_macros", + ], +) + +cc_library( + name = "tflite_wrapper", + hdrs = ["tflite_wrapper.h"], + deps = ["//tensorflow_lite_support/cc/port/default:tflite_wrapper"], +) + +# This is identical to the rule above, except that it gets built with +# '-DTFLITE_USE_C_API'. This rule is used for unit tests that verify things +# work correctly when built with TFLITE_USE_C_API defined. 
+cc_library( + name = "tflite_wrapper_with_c_api_for_test", + testonly = 1, + hdrs = ["tflite_wrapper.h"], + deps = [ + "//intelligence/mobile_acceleration/proto:allowlist_portable_proto", + "//intelligence/mobile_acceleration/support_library:tflite_wrapper_with_c_api_for_test", + ], +) + +cc_library( + name = "integral_types", + hdrs = ["integral_types.h"], +) + +cc_library( + name = "gtest", + testonly = 1, + hdrs = [ + "gmock.h", + "gtest.h", + ], + deps = [ + "//testing/base/public:gunit_for_library_testonly", + ], +) + +cc_library( + name = "gtest_main", + testonly = 1, + hdrs = [ + "benchmark.h", + "gmock.h", + "gtest.h", + ], + deps = [ + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow_lite_support/cc/port/benchmark.h b/tensorflow_lite_support/cc/port/benchmark.h new file mode 100644 index 00000000..74bc1a68 --- /dev/null +++ b/tensorflow_lite_support/cc/port/benchmark.h @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_ + +#include "gtest/benchmark.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_BENCHMARK_H_ diff --git a/tensorflow_lite_support/cc/port/build_defs.bzl b/tensorflow_lite_support/cc/port/build_defs.bzl new file mode 100644 index 00000000..a8053db2 --- /dev/null +++ b/tensorflow_lite_support/cc/port/build_defs.bzl @@ -0,0 +1,30 @@ +""".bzl file for TFLite Support open source build configs.""" + +load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library") + +def provided_args(**kwargs): + """Returns the keyword arguments omitting None arguments.""" + return {k: v for k, v in kwargs.items() if v != None} + +def support_cc_proto_library(name, srcs, visibility = None, deps = [], cc_deps = [], testonly = 0): + """Generate cc_proto_library for TFLite Support open source version. + + Args: + name: the name of the cc_proto_library. + srcs: the .proto files of the cc_proto_library for Bazel use. + visibility: visibility of this target. + deps: a list of dependency labels for Bazel use; must be cc_proto_library. + testonly: test only proto or not. 
+ """ + _ignore = [deps] + cc_proto_library(**provided_args( + name = name, + srcs = srcs, + visibility = visibility, + deps = cc_deps, + testonly = testonly, + cc_libs = ["@com_google_protobuf//:protobuf"], + protoc = "@com_google_protobuf//:protoc", + default_runtime = "@com_google_protobuf//:protobuf", + alwayslink = 1, + )) diff --git a/tensorflow_lite_support/cc/port/default/BUILD b/tensorflow_lite_support/cc/port/default/BUILD new file mode 100644 index 00000000..3f6e9e93 --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/BUILD @@ -0,0 +1,50 @@ +package( + default_visibility = [ + "//tensorflow_lite_support/cc/port:__pkg__", + "//tensorflow_lite_support/cc/test:__pkg__", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "statusor", + srcs = ["statusor.cc"], + hdrs = [ + "statusor.h", + "statusor_internals.h", + ], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", + "@com_google_absl//absl/utility", + "@com_google_glog//:glog", + ], +) + +cc_library( + name = "status_macros", + hdrs = [ + "status_macros.h", + ], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "tflite_wrapper", + srcs = ["tflite_wrapper.cc"], + hdrs = [ + "tflite_wrapper.h", + ], + deps = [ + "//tensorflow_lite_support/cc/port:status_macros", + "@com_google_absl//absl/status", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/experimental/acceleration/configuration:configuration_cc_proto", + ], +) diff --git a/tensorflow_lite_support/cc/port/default/status_macros.h b/tensorflow_lite_support/cc/port/default/status_macros.h new file mode 100644 index 00000000..47476c9c --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/status_macros.h @@ -0,0 +1,215 @@ +/* Copyright 2020 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is forked from absl. + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_ + +#include "absl/base/optimization.h" +#include "absl/status/status.h" + +// Evaluates an expression that produces a `absl::Status`. If the status is not +// ok, returns it from the current function. +// +// For example: +// absl::Status MultiStepFunction() { +// RETURN_IF_ERROR(Function(args...)); +// RETURN_IF_ERROR(foo.Method(args...)); +// return absl::OkStatus(); +// } +#define RETURN_IF_ERROR(expr) \ + STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (::tflite::support::status_macro_internal::StatusAdaptorForMacros \ + status_macro_internal_adaptor = {(expr)}) { \ + } else /* NOLINT */ \ + return status_macro_internal_adaptor.Consume() + +// Executes an expression `rexpr` that returns a `tflite::support::StatusOr<T>`. +// On OK, moves its value into the variable defined by `lhs`, otherwise returns +// from the current function. By default the error status is returned +// unchanged, but it may be modified by an `error_expression`. If there is an +// error, `lhs` is not evaluated; thus any side effects that `lhs` may have +// only occur in the success case. 
+// +// Interface: +// +// ASSIGN_OR_RETURN(lhs, rexpr) +// ASSIGN_OR_RETURN(lhs, rexpr, error_expression); +// +// WARNING: if lhs is parenthesized, the parentheses are removed. See examples +// for more details. +// +// WARNING: expands into multiple statements; it cannot be used in a single +// statement (e.g. as the body of an if statement without {})! +// +// Example: Declaring and initializing a new variable (ValueType can be anything +// that can be initialized with assignment, including references): +// ASSIGN_OR_RETURN(ValueType value, MaybeGetValue(arg)); +// +// Example: Assigning to an existing variable: +// ValueType value; +// ASSIGN_OR_RETURN(value, MaybeGetValue(arg)); +// +// Example: Assigning to an expression with side effects: +// MyProto data; +// ASSIGN_OR_RETURN(*data.mutable_str(), MaybeGetValue(arg)); +// // No field "str" is added on error. +// +// Example: Assigning to a std::unique_ptr. +// ASSIGN_OR_RETURN(std::unique_ptr<T> ptr, MaybeGetPtr(arg)); +// +// Example: Assigning to a map. Because of C preprocessor +// limitation, the type used in ASSIGN_OR_RETURN cannot contain comma, so +// wrap lhs in parentheses: +// ASSIGN_OR_RETURN((absl::flat_hash_map<Foo, Bar> my_map), GetMap()); +// Or use auto if the type is obvious enough: +// ASSIGN_OR_RETURN(const auto& my_map, GetMapRef()); +// +// Example: Assigning to structured bindings. The same situation with comma as +// in map, so wrap the statement in parentheses. +// ASSIGN_OR_RETURN((const auto& [first, second]), GetPair()); + +#define ASSIGN_OR_RETURN(...) \ + STATUS_MACROS_IMPL_GET_VARIADIC_((__VA_ARGS__, \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \ + (__VA_ARGS__) + +// ================================================================= +// == Implementation details, do not rely on anything below here. 
== +// ================================================================= + +// Some builds do not support C++14 fully yet, using C++11 constexpr technique. +constexpr bool TFLSHasPotentialConditionalOperator(const char* lhs, int index) { + return (index == -1 + ? false + : (lhs[index] == '?' + ? true + : TFLSHasPotentialConditionalOperator(lhs, index - 1))); +} + +// MSVC incorrectly expands variadic macros, splice together a macro call to +// work around the bug. +#define STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME +#define STATUS_MACROS_IMPL_GET_VARIADIC_(args) \ + STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args + +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, _) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ + STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ + error_expression) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ + error_expression) \ + auto statusor = (rexpr); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + ::absl::Status _(std::move(statusor).status()); \ + (void)_; /* error_expression is allowed to not use this variable */ \ + return (error_expression); \ + } \ + { \ + static_assert( \ + #lhs[0] != '(' || #lhs[sizeof(#lhs) - 2] != ')' || \ + !TFLSHasPotentialConditionalOperator(#lhs, sizeof(#lhs) - 2), \ + "Identified potential conditional operator, consider not " \ + "using ASSIGN_OR_RETURN"); \ + } \ + STATUS_MACROS_IMPL_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \ + std::move(statusor).value() + +// Internal helpers for macro expansion. +#define STATUS_MACROS_IMPL_EAT(...) +#define STATUS_MACROS_IMPL_REM(...) __VA_ARGS__ +#define STATUS_MACROS_IMPL_EMPTY() + +// Internal helpers for emptyness arguments check. +#define STATUS_MACROS_IMPL_IS_EMPTY_INNER(...) 
\ + STATUS_MACROS_IMPL_IS_EMPTY_INNER_I(__VA_ARGS__, 0, 1) +#define STATUS_MACROS_IMPL_IS_EMPTY_INNER_I(e0, e1, is_empty, ...) is_empty + +#define STATUS_MACROS_IMPL_IS_EMPTY(...) \ + STATUS_MACROS_IMPL_IS_EMPTY_I(__VA_ARGS__) +#define STATUS_MACROS_IMPL_IS_EMPTY_I(...) \ + STATUS_MACROS_IMPL_IS_EMPTY_INNER(_, ##__VA_ARGS__) + +// Internal helpers for if statement. +#define STATUS_MACROS_IMPL_IF_1(_Then, _Else) _Then +#define STATUS_MACROS_IMPL_IF_0(_Then, _Else) _Else +#define STATUS_MACROS_IMPL_IF(_Cond, _Then, _Else) \ + STATUS_MACROS_IMPL_CONCAT_(STATUS_MACROS_IMPL_IF_, _Cond) \ + (_Then, _Else) + +// Expands to 1 if the input is parenthesized. Otherwise expands to 0. +#define STATUS_MACROS_IMPL_IS_PARENTHESIZED(...) \ + STATUS_MACROS_IMPL_IS_EMPTY(STATUS_MACROS_IMPL_EAT __VA_ARGS__) + +// If the input is parenthesized, removes the parentheses. Otherwise expands to +// the input unchanged. +#define STATUS_MACROS_IMPL_UNPARENTHESIZE_IF_PARENTHESIZED(...) \ + STATUS_MACROS_IMPL_IF(STATUS_MACROS_IMPL_IS_PARENTHESIZED(__VA_ARGS__), \ + STATUS_MACROS_IMPL_REM, STATUS_MACROS_IMPL_EMPTY()) \ + __VA_ARGS__ + +// Internal helper for concatenating macro values. +#define STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y +#define STATUS_MACROS_IMPL_CONCAT_(x, y) STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) + +// The GNU compiler emits a warning for code like: +// +// if (foo) +// if (bar) { } else baz; +// +// because it thinks you might want the else to bind to the first if. This +// leads to problems with code like: +// +// if (do_expr) RETURN_IF_ERROR(expr) << "Some message"; +// +// The "switch (0) case 0:" idiom is used to suppress this. +#define STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + switch (0) \ + case 0: \ + default: // NOLINT + +namespace tflite { +namespace support { +namespace status_macro_internal { + +// Provides a conversion to bool so that it can be used inside an if statement +// that declares a variable. 
+class StatusAdaptorForMacros { + public: + StatusAdaptorForMacros(const ::absl::Status& status) // NOLINT + : status_(status) {} + + StatusAdaptorForMacros(::absl::Status&& status) // NOLINT + : status_(std::move(status)) {} + + StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; + StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete; + + explicit operator bool() const { return ABSL_PREDICT_TRUE(status_.ok()); } + + ::absl::Status&& Consume() { return std::move(status_); } + + private: + ::absl::Status status_; +}; + +} // namespace status_macro_internal +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUS_MACROS_H_ diff --git a/tensorflow_lite_support/cc/port/default/statusor.cc b/tensorflow_lite_support/cc/port/default/statusor.cc new file mode 100644 index 00000000..5cf1196a --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/statusor.cc @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is forked from absl. 
+ +#include "tensorflow_lite_support/cc/port/default/statusor.h" + +#include <utility> + +#include <glog/logging.h> +#include "absl/strings/str_cat.h" + +namespace tflite { +namespace support { + +BadStatusOrAccess::BadStatusOrAccess(absl::Status status) + : status_(std::move(status)) {} + +BadStatusOrAccess::~BadStatusOrAccess() = default; + +const char* BadStatusOrAccess::what() const noexcept { + return "Bad StatusOr access"; +} + +const absl::Status& BadStatusOrAccess::status() const { return status_; } + +namespace internal_statusor { + +void Helper::HandleInvalidStatusCtorArg(absl::Status* status) { + const char* kMessage = + "An OK status is not a valid constructor argument to StatusOr<T>"; + LOG(DFATAL) << kMessage; + // In optimized builds, we will fall back to ::util::error::INTERNAL. + *status = absl::InternalError(kMessage); +} + +void Helper::Crash(const absl::Status& status) { + LOG(FATAL) << "Attempting to fetch value instead of handling error " + << status; + _exit(1); +} + +void ThrowBadStatusOrAccess(absl::Status status) { +#ifdef ABSL_HAVE_EXCEPTIONS + throw BadStatusOrAccess(std::move(status)); +#else + LOG(FATAL) << "Attempting to fetch value instead of handling error " + << status; +#endif +} + +} // namespace internal_statusor +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/port/default/statusor.h b/tensorflow_lite_support/cc/port/default/statusor.h new file mode 100644 index 00000000..4273e1ce --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/statusor.h @@ -0,0 +1,574 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is forked from absl. + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_ + +#include <exception> +#include <initializer_list> +#include <new> +#include <string> +#include <type_traits> +#include <utility> + +#include "absl/base/optimization.h" +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/types/variant.h" +#include "absl/utility/utility.h" +#include "tensorflow_lite_support/cc/port/default/statusor_internals.h" + +namespace tflite { +namespace support { + +#ifndef SWIG +class BadStatusOrAccess : public std::exception { + public: + explicit BadStatusOrAccess(absl::Status status); + ~BadStatusOrAccess() override; + const char* what() const noexcept override; + const absl::Status& status() const; + + private: + absl::Status status_; +}; +#endif // !SWIG + +// Returned StatusOr objects may not be ignored. +// Note: Disabled for SWIG as it doesn't parse attributes correctly. Codesearch +// doesn't handle ifdefs as part of a class definitions (b/6995610), so we use a +// forward declaration. 
+#ifndef SWIG
+template <typename T>
+class ABSL_MUST_USE_RESULT StatusOr;
+#endif
+
+template <typename T>
+class StatusOr : private internal_statusor::StatusOrData<T>,
+                 private internal_statusor::CopyCtorBase<T>,
+                 private internal_statusor::MoveCtorBase<T>,
+                 private internal_statusor::CopyAssignBase<T>,
+                 private internal_statusor::MoveAssignBase<T> {
+  template <typename U>
+  friend class StatusOr;
+
+  typedef internal_statusor::StatusOrData<T> Base;
+
+ public:
+  typedef T value_type;
+
+  // Constructs a new StatusOr with Status::UNKNOWN status. This is marked
+  // 'explicit' to try to catch cases like 'return {};', where people think
+  // tflite::support::StatusOr<std::vector<int>> will be initialized with an
+  // empty vector, instead of a Status::UNKNOWN status.
+  explicit StatusOr();
+
+  // StatusOr<T> is copy constructible if T is copy constructible.
+  StatusOr(const StatusOr&) = default;
+  // StatusOr<T> is copy assignable if T is copy constructible and copy
+  // assignable.
+  StatusOr& operator=(const StatusOr&) = default;
+
+#ifndef SWIG
+
+  // StatusOr<T> is move constructible if T is move constructible.
+  StatusOr(StatusOr&&) = default;
+  // StatusOr<T> is move assignable if T is move constructible and move
+  // assignable.
+  StatusOr& operator=(StatusOr&&) = default;
+
+  // Converting constructors from StatusOr<U>, when T is constructible from U.
+  // To avoid ambiguity, they are disabled if T is also constructible from
+  // StatusOr<U>. Explicit iff the corresponding construction of T from U is
+  // explicit.
+ template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, + std::is_constructible<T, const U&>, + std::is_convertible<const U&, T>, + absl::negation< + internal_statusor::IsConstructibleOrConvertibleFromStatusOr< + T, U>>>::value, + int> = 0> + StatusOr(const StatusOr<U>& other) // NOLINT + : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {} + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, + std::is_constructible<T, const U&>, + absl::negation<std::is_convertible<const U&, T>>, + absl::negation< + internal_statusor::IsConstructibleOrConvertibleFromStatusOr< + T, U>>>::value, + int> = 0> + explicit StatusOr(const StatusOr<U>& other) + : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {} + + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, + std::is_convertible<U&&, T>, + absl::negation< + internal_statusor::IsConstructibleOrConvertibleFromStatusOr< + T, U>>>::value, + int> = 0> + StatusOr(StatusOr<U>&& other) // NOLINT + : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {} + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, + absl::negation<std::is_convertible<U&&, T>>, + absl::negation< + internal_statusor::IsConstructibleOrConvertibleFromStatusOr< + T, U>>>::value, + int> = 0> + explicit StatusOr(StatusOr<U>&& other) + : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {} + + // Conversion copy/move assignment operator, T must be constructible and + // assignable from U. Only enable if T cannot be directly assigned from + // StatusOr<U>. 
+ template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, + std::is_constructible<T, const U&>, + std::is_assignable<T, const U&>, + absl::negation< + internal_statusor:: + IsConstructibleOrConvertibleOrAssignableFromStatusOr< + T, U>>>::value, + int> = 0> + StatusOr& operator=(const StatusOr<U>& other) { + this->Assign(other); + return *this; + } + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, + std::is_assignable<T, U&&>, + absl::negation< + internal_statusor:: + IsConstructibleOrConvertibleOrAssignableFromStatusOr< + T, U>>>::value, + int> = 0> + StatusOr& operator=(StatusOr<U>&& other) { + this->Assign(std::move(other)); + return *this; + } + +#endif // SWIG + + // Constructs a new StatusOr with the given value. After calling this + // constructor, this->ok() will be true and the contained value may be + // retrieved with value(), operator*(), or operator->(). + // + // NOTE: Not explicit - we want to use StatusOr<T> as a return type + // so it is convenient and sensible to be able to do 'return T()' + // when the return type is StatusOr<T>. + // + // REQUIRES: T is copy constructible. + // TODO(b/113125838): Replace this constructor with a direct-initialization + // constructor. + StatusOr(const T& value); + + // Constructs a new StatusOr with the given non-ok status. After calling this + // constructor, this->ok() will be false and calls to value() will CHECK-fail. + // + // NOTE: Not explicit - we want to use StatusOr<T> as a return + // value, so it is convenient and sensible to be able to do 'return + // Status()' when the return type is StatusOr<T>. + // + // REQUIRES: !status.ok(). This requirement is DCHECKed. + // In optimized builds, passing util::OkStatus() here will have the effect + // of passing util::error::INTERNAL as a fallback. 
+ StatusOr(const absl::Status& status); + StatusOr& operator=(const absl::Status& status); + +#ifndef SWIG + // Perfect-forwarding value assignment operator. + // If `*this` contains a `T` value before the call, the contained value is + // assigned from `std::forward<U>(v)`; Otherwise, it is directly-initialized + // from `std::forward<U>(v)`. + // This function does not participate in overload unless: + // 1. `std::is_constructible_v<T, U>` is true, + // 2. `std::is_assignable_v<T&, U>` is true. + // 3. `std::is_same_v<StatusOr<T>, std::remove_cvref_t<U>>` is false. + // 4. Assigning `U` to `T` is not ambiguous: + // If `U` is `StatusOr<V>` and `T` is constructible and assignable from + // both `StatusOr<V>` and `V`, the assignment is considered bug-prone and + // ambiguous thus will fail to compile. For example: + // StatusOr<bool> s1 = true; // s1.ok() && *s1 == true + // StatusOr<bool> s2 = false; // s2.ok() && *s2 == false + // s1 = s2; // ambiguous, `s1 = *s2` or `s1 = bool(s2)`? + template < + typename U = T, + typename = typename std::enable_if<absl::conjunction< + std::is_constructible<T, U&&>, std::is_assignable<T&, U&&>, + internal_statusor::IsForwardingAssignmentValid<T, U&&>>::value>::type> + StatusOr& operator=(U&& v) { + this->Assign(std::forward<U>(v)); + return *this; + } + + // Similar to the `const T&` overload. + // + // REQUIRES: T is move constructible. + StatusOr(T&& value); + + // RValue versions of the operations declared above. + StatusOr(absl::Status&& status); + StatusOr& operator=(absl::Status&& status); + + // Constructs the inner value T in-place using the provided args, using the + // T(args...) constructor. + template <typename... Args> + explicit StatusOr(absl::in_place_t, Args&&... args); + template <typename U, typename... Args> + explicit StatusOr(absl::in_place_t, std::initializer_list<U> ilist, + Args&&... 
args); + + // Constructs the inner value T in-place using the provided args, using the + // T(U) (direct-initialization) constructor. Only valid if T can be + // constructed from a U. Can accept move or copy constructors. Explicit if + // U is not convertible to T. To avoid ambiguity, this is disabled if U is + // a StatusOr<J>, where J is convertible to T. + // Style waiver for implicit conversion granted in cl/209187539. + template <typename U = T, + absl::enable_if_t< + absl::conjunction< + internal_statusor::IsDirectInitializationValid<T, U&&>, + std::is_constructible<T, U&&>, + std::is_convertible<U&&, T>>::value, + int> = 0> + StatusOr(U&& u) // NOLINT + : StatusOr(absl::in_place, std::forward<U>(u)) {} + + template <typename U = T, + absl::enable_if_t< + absl::conjunction< + internal_statusor::IsDirectInitializationValid<T, U&&>, + std::is_constructible<T, U&&>, + absl::negation<std::is_convertible<U&&, T>>>::value, + int> = 0> + explicit StatusOr(U&& u) // NOLINT + : StatusOr(absl::in_place, std::forward<U>(u)) {} + +#endif // SWIG + + // Returns this->status().ok() + ABSL_MUST_USE_RESULT bool ok() const { return this->status_.ok(); } + + // Returns a reference to our status. If this contains a T, then + // returns util::OkStatus(). +#ifdef SWIG + const ::util::Status& status() const; +#else // SWIG + const absl::Status& status() const&; + absl::Status status() &&; +#endif // SWIG + + // Returns a reference to the held value if `this->ok()`. Otherwise, throws + // `absl::BadStatusOrAccess` if exception is enabled, or `LOG(FATAL)` if + // exception is disabled. + // If you have already checked the status using `this->ok()` or + // `operator bool()`, you probably want to use `operator*()` or `operator->()` + // to access the value instead of `value`. 
+ // Note: for value types that are cheap to copy, prefer simple code: + // + // T value = statusor.value(); + // + // Otherwise, if the value type is expensive to copy, but can be left + // in the StatusOr, simply assign to a reference: + // + // T& value = statusor.value(); // or `const T&` + // + // Otherwise, if the value type supports an efficient move, it can be + // used as follows: + // + // T value = std::move(statusor).value(); + // + // The `std::move` on statusor instead of on the whole expression enables + // warnings about possible uses of the statusor object after the move. +#ifdef SWIG + const T& value() const; +#else // SWIG + const T& value() const&; + T& value() &; + const T&& value() const&&; + T&& value() &&; +#endif // SWIG + +#ifndef SWIG + // Returns a reference to the current value. + // + // REQUIRES: this->ok() == true, otherwise the behavior is undefined. + // + // Use this->ok() or `operator bool()` to verify that there is a current + // value. Alternatively, see value() for a similar API that guarantees + // CHECK-failing if there is no current value. + const T& operator*() const&; + T& operator*() &; + const T&& operator*() const&&; + T&& operator*() &&; +#endif // SWIG + +#ifndef SWIG + // Returns a pointer to the current value. + // + // REQUIRES: this->ok() == true, otherwise the behavior is undefined. + // + // Use this->ok() or `operator bool()` to verify that there is a current + // value. + const T* operator->() const; + T* operator->(); +#endif // SWIG + +#ifndef SWIG + // Returns a copy of the current value if this->ok() == true. Otherwise + // returns a default value. + template <typename U> + T value_or(U&& default_value) const&; + template <typename U> + T value_or(U&& default_value) &&; +#endif // SWIG + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. 
+ void IgnoreError() const; + +#ifndef SWIG + // Reconstructs the inner value T in-place using the provided args, using the + // T(args...) constructor. Returns reference to the reconstructed `T`. + template <typename... Args> + T& emplace(Args&&... args) { + if (ok()) { + this->Clear(); + this->MakeValue(std::forward<Args>(args)...); + } else { + this->MakeValue(std::forward<Args>(args)...); + this->status_ = absl::OkStatus(); + } + return this->data_; + } + + template < + typename U, typename... Args, + absl::enable_if_t< + std::is_constructible<T, std::initializer_list<U>&, Args&&...>::value, + int> = 0> + T& emplace(std::initializer_list<U> ilist, Args&&... args) { + if (ok()) { + this->Clear(); + this->MakeValue(ilist, std::forward<Args>(args)...); + } else { + this->MakeValue(ilist, std::forward<Args>(args)...); + this->status_ = absl::OkStatus(); + } + return this->data_; + } +#endif // SWIG + + private: +#ifndef SWIG + using internal_statusor::StatusOrData<T>::Assign; + template <typename U> + void Assign(const StatusOr<U>& other); + template <typename U> + void Assign(StatusOr<U>&& other); +#endif // SWIG +}; + +#ifndef SWIG +//////////////////////////////////////////////////////////////////////////////// +// Implementation details for StatusOr<T> + +template <typename T> +tflite::support::StatusOr<T>::StatusOr() + : Base(absl::Status(absl::StatusCode::kUnknown, "")) {} + +template <typename T> +tflite::support::StatusOr<T>::StatusOr(const T& value) : Base(value) {} + +template <typename T> +tflite::support::StatusOr<T>::StatusOr(const absl::Status& status) + : Base(status) {} + +template <typename T> +tflite::support::StatusOr<T>& StatusOr<T>::operator=( + const absl::Status& status) { + this->Assign(status); + return *this; +} + +template <typename T> +tflite::support::StatusOr<T>::StatusOr(T&& value) : Base(std::move(value)) {} + +template <typename T> +tflite::support::StatusOr<T>::StatusOr(absl::Status&& status) + : Base(std::move(status)) {} + 
+template <typename T> +tflite::support::StatusOr<T>& StatusOr<T>::operator=(absl::Status&& status) { + this->Assign(std::move(status)); + return *this; +} + +template <typename T> +template <typename U> +inline void StatusOr<T>::Assign(const StatusOr<U>& other) { + if (other.ok()) { + this->Assign(other.value()); + } else { + this->Assign(other.status()); + } +} + +template <typename T> +template <typename U> +inline void StatusOr<T>::Assign(StatusOr<U>&& other) { + if (other.ok()) { + this->Assign(std::move(other).value()); + } else { + this->Assign(std::move(other).status()); + } +} +template <typename T> +template <typename... Args> +tflite::support::StatusOr<T>::StatusOr(absl::in_place_t, Args&&... args) + : Base(absl::in_place, std::forward<Args>(args)...) {} + +template <typename T> +template <typename U, typename... Args> +tflite::support::StatusOr<T>::StatusOr(absl::in_place_t, + std::initializer_list<U> ilist, + Args&&... args) + : Base(absl::in_place, ilist, std::forward<Args>(args)...) {} + +template <typename T> +const absl::Status& StatusOr<T>::status() const& { + return this->status_; +} +template <typename T> +absl::Status StatusOr<T>::status() && { + return ok() ? 
absl::OkStatus() : std::move(this->status_); +} + +template <typename T> +const T& StatusOr<T>::value() const& { + if (!this->ok()) internal_statusor::ThrowBadStatusOrAccess(this->status_); + return this->data_; +} + +template <typename T> +T& StatusOr<T>::value() & { + if (!this->ok()) internal_statusor::ThrowBadStatusOrAccess(this->status_); + return this->data_; +} + +template <typename T> +const T&& StatusOr<T>::value() const&& { + if (!this->ok()) { + internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_)); + } + return std::move(this->data_); +} + +template <typename T> +T&& StatusOr<T>::value() && { + if (!this->ok()) { + internal_statusor::ThrowBadStatusOrAccess(std::move(this->status_)); + } + return std::move(this->data_); +} + +template <typename T> +const T& StatusOr<T>::operator*() const& { + this->EnsureOk(); + return this->data_; +} + +template <typename T> +T& StatusOr<T>::operator*() & { + this->EnsureOk(); + return this->data_; +} + +template <typename T> +const T&& StatusOr<T>::operator*() const&& { + this->EnsureOk(); + return std::move(this->data_); +} + +template <typename T> +T&& StatusOr<T>::operator*() && { + this->EnsureOk(); + return std::move(this->data_); +} + +template <typename T> +const T* StatusOr<T>::operator->() const { + this->EnsureOk(); + return &this->data_; +} + +template <typename T> +T* StatusOr<T>::operator->() { + this->EnsureOk(); + return &this->data_; +} + +template <typename T> +template <typename U> +T StatusOr<T>::value_or(U&& default_value) const& { + if (ok()) { + return this->data_; + } + return std::forward<U>(default_value); +} + +template <typename T> +template <typename U> +T StatusOr<T>::value_or(U&& default_value) && { + if (ok()) { + return std::move(this->data_); + } + return std::forward<U>(default_value); +} + +template <typename T> +void StatusOr<T>::IgnoreError() const { + // no-op +} + +#endif // SWIG + +} // namespace support +} // namespace tflite + +#endif // 
TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_H_ diff --git a/tensorflow_lite_support/cc/port/default/statusor_internals.h b/tensorflow_lite_support/cc/port/default/statusor_internals.h new file mode 100644 index 00000000..56d46616 --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/statusor_internals.h @@ -0,0 +1,409 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is forked from absl. + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_INTERNALS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_INTERNALS_H_ + +#include <type_traits> +#include <utility> + +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/utility/utility.h" + +namespace tflite { +namespace support { + +template <typename T> +class ABSL_MUST_USE_RESULT StatusOr; + +namespace internal_statusor { + +// Detects whether `T` is constructible or convertible from `StatusOr<U>`. 
+template <typename T, typename U>
+using IsConstructibleOrConvertibleFromStatusOr =
+    absl::disjunction<std::is_constructible<T, StatusOr<U>&>,
+                      std::is_constructible<T, const StatusOr<U>&>,
+                      std::is_constructible<T, StatusOr<U>&&>,
+                      std::is_constructible<T, const StatusOr<U>&&>,
+                      std::is_convertible<StatusOr<U>&, T>,
+                      std::is_convertible<const StatusOr<U>&, T>,
+                      std::is_convertible<StatusOr<U>&&, T>,
+                      std::is_convertible<const StatusOr<U>&&, T>>;
+
+// Detects whether `T` is constructible or convertible or assignable from
+// `StatusOr<U>`.
+template <typename T, typename U>
+using IsConstructibleOrConvertibleOrAssignableFromStatusOr =
+    absl::disjunction<IsConstructibleOrConvertibleFromStatusOr<T, U>,
+                      std::is_assignable<T&, StatusOr<U>&>,
+                      std::is_assignable<T&, const StatusOr<U>&>,
+                      std::is_assignable<T&, StatusOr<U>&&>,
+                      std::is_assignable<T&, const StatusOr<U>&&>>;
+
+// Detects whether direct initializing `StatusOr<T>` from `U` is ambiguous, i.e.
+// when `U` is `StatusOr<V>` and `T` is constructible or convertible from `V`.
+template <typename T, typename U>
+struct IsDirectInitializationAmbiguous
+    : public absl::conditional_t<
+          std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>,
+                       U>::value,
+          std::false_type,
+          IsDirectInitializationAmbiguous<
+              T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
+
+template <typename T, typename V>
+struct IsDirectInitializationAmbiguous<T, tflite::support::StatusOr<V>>
+    : public IsConstructibleOrConvertibleFromStatusOr<T, V> {};
+
+// Checks against the constraints of the direct initialization, i.e. when
+// `StatusOr<T>::StatusOr(U&&)` should participate in overload resolution.
+template <typename T, typename U>
+using IsDirectInitializationValid = absl::disjunction<
+    // Short circuits if T is basically U.
+    std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>,
+    absl::negation<absl::disjunction<
+        std::is_same<tflite::support::StatusOr<T>,
+                     absl::remove_cv_t<absl::remove_reference_t<U>>>,
+        std::is_same<absl::Status,
+                     absl::remove_cv_t<absl::remove_reference_t<U>>>,
+        std::is_same<absl::in_place_t,
+                     absl::remove_cv_t<absl::remove_reference_t<U>>>,
+        IsDirectInitializationAmbiguous<T, U>>>>;
+
+// This trait detects whether `StatusOr<T>::operator=(U&&)` is ambiguous, which
+// is equivalent to whether all the following conditions are met:
+// 1. `U` is `StatusOr<V>`.
+// 2. `T` is constructible and assignable from `V`.
+// 3. `T` is constructible and assignable from `U` (i.e. `StatusOr<V>`).
+// For example, the following code is considered ambiguous:
+// (`T` is `bool`, `U` is `StatusOr<bool>`, `V` is `bool`)
+//   StatusOr<bool> s1 = true;  // s1.ok() && s1.ValueOrDie() == true
+//   StatusOr<bool> s2 = false;  // s2.ok() && s2.ValueOrDie() == false
+//   s1 = s2;  // ambiguous, `s1 = s2.ValueOrDie()` or `s1 = bool(s2)`?
+template <typename T, typename U>
+struct IsForwardingAssignmentAmbiguous
+    : public absl::conditional_t<
+          std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>,
+                       U>::value,
+          std::false_type,
+          IsForwardingAssignmentAmbiguous<
+              T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {};
+
+template <typename T, typename U>
+struct IsForwardingAssignmentAmbiguous<T, tflite::support::StatusOr<U>>
+    : public IsConstructibleOrConvertibleOrAssignableFromStatusOr<T, U> {};
+
+// Checks against the constraints of the forwarding assignment, i.e. whether
+// `StatusOr<T>::operator=(U&&)` should participate in overload resolution.
+template <typename T, typename U>
+using IsForwardingAssignmentValid = absl::disjunction<
+    // Short circuits if T is basically U.
+ std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>, + absl::negation<absl::disjunction< + std::is_same<tflite::support::StatusOr<T>, + absl::remove_cv_t<absl::remove_reference_t<U>>>, + std::is_same<absl::Status, + absl::remove_cv_t<absl::remove_reference_t<U>>>, + std::is_same<absl::in_place_t, + absl::remove_cv_t<absl::remove_reference_t<U>>>, + IsForwardingAssignmentAmbiguous<T, U>>>>; + +class Helper { + public: + // Move type-agnostic error handling to the .cc. + static void HandleInvalidStatusCtorArg(absl::Status*); + ABSL_ATTRIBUTE_NORETURN static void Crash(const absl::Status& status); +}; + +// Construct an instance of T in `p` through placement new, passing Args... to +// the constructor. +// This abstraction is here mostly for the gcc performance fix. +template <typename T, typename... Args> +void PlacementNew(void* p, Args&&... args) { +#if defined(__GNUC__) && !defined(__clang__) + // Teach gcc that 'p' cannot be null, fixing code size issues. + if (p == nullptr) __builtin_unreachable(); +#endif + new (p) T(std::forward<Args>(args)...); +} + +// Helper base class to hold the data and all operations. +// We move all this to a base class to allow mixing with the appropriate +// TraitsBase specialization. 
+template <typename T> +class StatusOrData { + template <typename U> + friend class StatusOrData; + + public: + StatusOrData() = delete; + + StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + StatusOrData(StatusOrData&& other) noexcept { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + template <typename U> + explicit StatusOrData(const StatusOrData<U>& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + template <typename U> + explicit StatusOrData(StatusOrData<U>&& other) { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + template <typename... Args> + explicit StatusOrData(absl::in_place_t, Args&&... args) + : data_(std::forward<Args>(args)...) { + MakeStatus(); + } + + explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } + explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } + + explicit StatusOrData(const absl::Status& status) : status_(status) { + EnsureNotOk(); + } + explicit StatusOrData(absl::Status&& status) : status_(std::move(status)) { + EnsureNotOk(); + } + + StatusOrData& operator=(const StatusOrData& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(other.data_); + else + Assign(other.status_); + return *this; + } + + StatusOrData& operator=(StatusOrData&& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(std::move(other.data_)); + else + Assign(std::move(other.status_)); + return *this; + } + + ~StatusOrData() { + if (ok()) { + status_.~Status(); + data_.~T(); + } else { + status_.~Status(); + } + } + + // TODO(b/140189837): Remove the SFINAE condition after cleanup. 
+ template <typename U, + absl::enable_if_t<std::is_assignable<T&, U&&>::value, int> = 0> + void Assign(U&& value) { + if (ok()) { + data_ = std::forward<U>(value); + } else { + MakeValue(std::forward<U>(value)); + status_ = absl::OkStatus(); + } + } + + // TODO(b/140189837): Remove this after cleanup. + // This overload is to handle the case where `T` is a `const` type. + // `StatusOr` supports assignment for `const` types though it's forbidden by + // other standard types like `std::optional`. + template <typename U, + absl::enable_if_t<!std::is_assignable<T&, U&&>::value, int> = 0> + void Assign(U&& value) { + if (ok()) { + data_.~T(); + MakeValue(std::forward<U>(value)); + } else { + MakeValue(std::forward<U>(value)); + status_ = absl::OkStatus(); + } + } + + void Assign(const absl::Status& status) { + Clear(); + status_ = status; + EnsureNotOk(); + } + + void Assign(absl::Status&& status) { + Clear(); + status_ = std::move(status); + EnsureNotOk(); + } + + bool ok() const { return status_.ok(); } + + protected: + // status_ will always be active after the constructor. + // We make it a union to be able to initialize exactly how we need without + // waste. + // Eg. in the copy constructor we use the default constructor of Status in + // the ok() path to avoid an extra Ref call. + union { + absl::Status status_; + }; + + // data_ is active iff status_.ok()==true + struct Dummy {}; + union { + // When T is const, we need some non-const object we can cast to void* for + // the placement new. dummy_ is that object. + Dummy dummy_; + T data_; + }; + + void Clear() { + if (ok()) data_.~T(); + } + + void EnsureOk() const { + if (ABSL_PREDICT_FALSE(!ok())) Helper::Crash(status_); + } + + void EnsureNotOk() { + if (ABSL_PREDICT_FALSE(ok())) Helper::HandleInvalidStatusCtorArg(&status_); + } + + // Construct the value (ie. data_) through placement new with the passed + // argument. + template <typename... Arg> + void MakeValue(Arg&&... 
arg) { + internal_statusor::PlacementNew<T>(&dummy_, std::forward<Arg>(arg)...); + } + + // Construct the status (ie. status_) through placement new with the passed + // argument. + template <typename... Args> + void MakeStatus(Args&&... args) { + internal_statusor::PlacementNew<absl::Status>(&status_, + std::forward<Args>(args)...); + } +}; + +// Helper base classes to allow implicitly deleted constructors and assignment +// operators in `StatusOr`. For example, `CopyCtorBase` will explicitly delete +// the copy constructor when T is not copy constructible and `StatusOr` will +// inherit that behavior implicitly. +template <typename T, bool = std::is_copy_constructible<T>::value> +struct CopyCtorBase { + CopyCtorBase() = default; + CopyCtorBase(const CopyCtorBase&) = default; + CopyCtorBase(CopyCtorBase&&) = default; + CopyCtorBase& operator=(const CopyCtorBase&) = default; + CopyCtorBase& operator=(CopyCtorBase&&) = default; +}; + +template <typename T> +struct CopyCtorBase<T, false> { + CopyCtorBase() = default; + CopyCtorBase(const CopyCtorBase&) = delete; + CopyCtorBase(CopyCtorBase&&) = default; + CopyCtorBase& operator=(const CopyCtorBase&) = default; + CopyCtorBase& operator=(CopyCtorBase&&) = default; +}; + +template <typename T, bool = std::is_move_constructible<T>::value> +struct MoveCtorBase { + MoveCtorBase() = default; + MoveCtorBase(const MoveCtorBase&) = default; + MoveCtorBase(MoveCtorBase&&) = default; + MoveCtorBase& operator=(const MoveCtorBase&) = default; + MoveCtorBase& operator=(MoveCtorBase&&) = default; +}; + +template <typename T> +struct MoveCtorBase<T, false> { + MoveCtorBase() = default; + MoveCtorBase(const MoveCtorBase&) = default; + MoveCtorBase(MoveCtorBase&&) = delete; + MoveCtorBase& operator=(const MoveCtorBase&) = default; + MoveCtorBase& operator=(MoveCtorBase&&) = default; +}; + +template <typename T, bool = std::is_copy_constructible<T>::value&& + std::is_copy_assignable<T>::value> +struct CopyAssignBase { + CopyAssignBase() 
= default; + CopyAssignBase(const CopyAssignBase&) = default; + CopyAssignBase(CopyAssignBase&&) = default; + CopyAssignBase& operator=(const CopyAssignBase&) = default; + CopyAssignBase& operator=(CopyAssignBase&&) = default; +}; + +template <typename T> +struct CopyAssignBase<T, false> { + CopyAssignBase() = default; + CopyAssignBase(const CopyAssignBase&) = default; + CopyAssignBase(CopyAssignBase&&) = default; + CopyAssignBase& operator=(const CopyAssignBase&) = delete; + CopyAssignBase& operator=(CopyAssignBase&&) = default; +}; + +template <typename T, bool = std::is_move_constructible<T>::value&& + std::is_move_assignable<T>::value> +struct MoveAssignBase { + MoveAssignBase() = default; + MoveAssignBase(const MoveAssignBase&) = default; + MoveAssignBase(MoveAssignBase&&) = default; + MoveAssignBase& operator=(const MoveAssignBase&) = default; + MoveAssignBase& operator=(MoveAssignBase&&) = default; +}; + +template <typename T> +struct MoveAssignBase<T, false> { + MoveAssignBase() = default; + MoveAssignBase(const MoveAssignBase&) = default; + MoveAssignBase(MoveAssignBase&&) = default; + MoveAssignBase& operator=(const MoveAssignBase&) = default; + MoveAssignBase& operator=(MoveAssignBase&&) = delete; +}; + +void ThrowBadStatusOrAccess(absl::Status status); + +} // namespace internal_statusor +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_STATUSOR_INTERNALS_H_ diff --git a/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc new file mode 100644 index 00000000..548e679a --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h" + +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" + +namespace tflite { +namespace support { + +absl::Status TfLiteInterpreterWrapper::InitializeWithFallback( + std::function<absl::Status(std::unique_ptr<tflite::Interpreter>*)> + interpreter_initializer, + const tflite::proto::ComputeSettings& compute_settings) { + if (compute_settings.has_preference() || + compute_settings.has_tflite_settings()) { + return absl::UnimplementedError( + "Acceleration via ComputeSettings is not supported yet."); + } + RETURN_IF_ERROR(interpreter_initializer(&interpreter_)); + return interpreter_->AllocateTensors() != kTfLiteOk + ? absl::InternalError( + "TFLite interpreter: AllocateTensors() failed.") + : absl::OkStatus(); +} + +absl::Status TfLiteInterpreterWrapper::InvokeWithFallback( + const std::function<absl::Status(tflite::Interpreter* interpreter)>& + set_inputs) { + RETURN_IF_ERROR(set_inputs(interpreter_.get())); + return interpreter_->Invoke() != kTfLiteOk + ? absl::InternalError("TFLite interpreter: Invoke() failed.") + : absl::OkStatus(); +} + +absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() { + return interpreter_->Invoke() != kTfLiteOk + ? 
absl::InternalError("TFLite interpreter: Invoke() failed.") + : absl::OkStatus(); +} + +void TfLiteInterpreterWrapper::Cancel() { + // NOP +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/port/default/tflite_wrapper.h b/tensorflow_lite_support/cc/port/default/tflite_wrapper.h new file mode 100644 index 00000000..3fd489f7 --- /dev/null +++ b/tensorflow_lite_support/cc/port/default/tflite_wrapper.h @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_ + +#include <memory> +#include <utility> + +#include "absl/status/status.h" +#include "tensorflow/lite/experimental/acceleration/configuration/configuration.pb.h" +#include "tensorflow/lite/interpreter.h" + +namespace tflite { +namespace support { + +// Wrapper for a TfLiteInterpreter that may be accelerated[1]. This is NOT yet +// implemented: this class only provides a first, minimal interface in the +// meanwhile. +// +// [1] See tensorflow/lite/experimental/acceleration for more details. +class TfLiteInterpreterWrapper { + public: + TfLiteInterpreterWrapper() = default; + + virtual ~TfLiteInterpreterWrapper() = default; + + // Calls `interpreter_initializer` and then `AllocateTensors`. 
Future
+  // implementation of this method will attempt to apply the provided
+  // `compute_settings` with a graceful fallback in case a failure occurs.
+  // Note: before this gets implemented, do NOT call this method with non-empty
+  // `compute_settings` otherwise an unimplemented error occurs.
+  absl::Status InitializeWithFallback(
+      std::function<absl::Status(std::unique_ptr<tflite::Interpreter>*)>
+          interpreter_initializer,
+      const tflite::proto::ComputeSettings& compute_settings);
+
+  // Calls `set_inputs` and then Invoke() on the interpreter. Future
+  // implementation of this method will perform a graceful fallback in case a
+  // failure occurs due to the `compute_settings` provided at initialization
+  // time.
+  absl::Status InvokeWithFallback(
+      const std::function<absl::Status(tflite::Interpreter* interpreter)>&
+          set_inputs);
+
+  // Calls Invoke() on the interpreter. Caller must have set up inputs
+  // before-hand.
+  absl::Status InvokeWithoutFallback();
+
+  // Cancels the current running TFLite invocation on CPU. This method is not
+  // yet implemented though it is safe to use it as it acts as a NOP.
+  void Cancel();
+
+  // Accesses the underlying interpreter for other methods.
+ tflite::Interpreter& operator*() { return *interpreter_; } + tflite::Interpreter* operator->() { return interpreter_.get(); } + tflite::Interpreter& operator*() const { return *interpreter_; } + tflite::Interpreter* operator->() const { return interpreter_.get(); } + tflite::Interpreter* get() const { return interpreter_.get(); } + + TfLiteInterpreterWrapper(const TfLiteInterpreterWrapper&) = delete; + TfLiteInterpreterWrapper& operator=(const TfLiteInterpreterWrapper&) = delete; + + private: + std::unique_ptr<tflite::Interpreter> interpreter_; +}; + +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_DEFAULT_TFLITE_WRAPPER_H_ diff --git a/tensorflow_lite_support/cc/port/gmock.h b/tensorflow_lite_support/cc/port/gmock.h new file mode 100644 index 00000000..5e4334db --- /dev/null +++ b/tensorflow_lite_support/cc/port/gmock.h @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_ + +#include "gmock/gmock.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GMOCK_H_ diff --git a/tensorflow_lite_support/cc/port/gtest.h b/tensorflow_lite_support/cc/port/gtest.h new file mode 100644 index 00000000..dbe2e5e6 --- /dev/null +++ b/tensorflow_lite_support/cc/port/gtest.h @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_ + +#include "gtest/gtest.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_GTEST_H_ diff --git a/tensorflow_lite_support/cc/port/integral_types.h b/tensorflow_lite_support/cc/port/integral_types.h new file mode 100644 index 00000000..76d9d503 --- /dev/null +++ b/tensorflow_lite_support/cc/port/integral_types.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_INTEGRAL_TYPES_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_INTEGRAL_TYPES_H_ + +// Add namespace here to avoid conflict with other libraries. +namespace tflite { + +typedef signed char schar; +typedef signed char int8; +typedef short int16; +typedef int int32; +typedef long long int64; + +typedef unsigned char uint8; +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned int char32; +typedef unsigned long long uint64; +typedef unsigned long uword_t; + +#define GG_LONGLONG(x) x##LL +#define GG_ULONGLONG(x) x##ULL +#define GG_LL_FORMAT "ll" // As in "%lld". Note that "q" is poor form also. +#define GG_LL_FORMAT_W L"ll" + +typedef uint64 Fprint; +static const Fprint kIllegalFprint = 0; +static const Fprint kMaxFprint = GG_ULONGLONG(0xFFFFFFFFFFFFFFFF); + +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_INTEGRAL_TYPES_H_ diff --git a/tensorflow_lite_support/cc/port/status_macros.h b/tensorflow_lite_support/cc/port/status_macros.h new file mode 100644 index 00000000..3890c772 --- /dev/null +++ b/tensorflow_lite_support/cc/port/status_macros.h @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MACROS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MACROS_H_ + +#include "tensorflow_lite_support/cc/port/default/status_macros.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUS_MACROS_H_ diff --git a/tensorflow_lite_support/cc/port/statusor.h b/tensorflow_lite_support/cc/port/statusor.h new file mode 100644 index 00000000..f84c7568 --- /dev/null +++ b/tensorflow_lite_support/cc/port/statusor.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_ + +#include "tensorflow_lite_support/cc/port/default/statusor.h" +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_STATUSOR_H_ diff --git a/tensorflow_lite_support/cc/port/tflite_wrapper.h b/tensorflow_lite_support/cc/port/tflite_wrapper.h new file mode 100644 index 00000000..601df9b4 --- /dev/null +++ b/tensorflow_lite_support/cc/port/tflite_wrapper.h @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_PORT_TFLITE_WRAPPER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_PORT_TFLITE_WRAPPER_H_ + +#include "tensorflow_lite_support/cc/port/default/tflite_wrapper.h" + +#endif // TENSORFLOW_LITE_SUPPORT_CC_PORT_TFLITE_WRAPPER_H_ diff --git a/tensorflow_lite_support/cc/task/README.md b/tensorflow_lite_support/cc/task/README.md new file mode 100644 index 00000000..bd756a2e --- /dev/null +++ b/tensorflow_lite_support/cc/task/README.md @@ -0,0 +1,384 @@ +# TFLite Task library - C++ + +A flexible and ready-to-use library for common machine learning model types, +such as classification and detection. 
+ +## Text Task Libraries + +### QuestionAnswerer + +`QuestionAnswerer` API is able to load +[Mobile BERT](https://tfhub.dev/tensorflow/mobilebert/1) or +[AlBert](https://tfhub.dev/tensorflow/albert_lite_base/1) TFLite models and +answer question based on context. + +Use the C++ API to answer questions as follows: + +```cc +using tflite::task::text::qa::BertQuestionAnswerer; +using tflite::task::text::qa::QaAnswer; +// Create API handler with Mobile Bert model. +auto qa_client = BertQuestionAnswerer::CreateBertQuestionAnswererFromFile("/path/to/mobileBertModel", "/path/to/vocab"); +// Or create API handler with Albert model. +// auto qa_client = BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile("/path/to/alBertModel", "/path/to/sentencePieceModel"); + + +std::string context = + "Nikola Tesla (Serbian Cyrillic: Никола Тесла; 10 " + "July 1856 – 7 January 1943) was a Serbian American inventor, electrical " + "engineer, mechanical engineer, physicist, and futurist best known for his " + "contributions to the design of the modern alternating current (AC) " + "electricity supply system."; +std::string question = "When was Nikola Tesla born?"; +// Run inference with `context` and a given `question` to the context, and get top-k +// answers ranked by logits. +const std::vector<QaAnswer> answers = qa_client->Answer(context, question); +// Access QaAnswer results. +for (const QaAnswer& item : answers) { + std::cout << absl::StrFormat("Text: %s logit=%f start=%d end=%d", item.text, + item.pos.logit, item.pos.start, item.pos.end) + << std::endl; +} +// Output: +// Text: 10 July 1856 logit=16.8527 start=17 end=19 +// ... (and more) +// +// So the top-1 answer is: "10 July 1856". +``` + +In the above code, `item.text` is the text content of an answer. We use a span +with closed interval `[item.pos.start, item.pos.end]` to denote predicted tokens +in the answer, and `item.pos.logit` is the sum of span logits to represent the +confidence score. 
+
+### NLClassifier
+
+`NLClassifier` API is able to load any TFLite models for natural language
+classification task such as language detection or sentiment detection.
+
+The API expects a TFLite model with the following input/output tensor:
+Input tensor0:
+  (kTfLiteString) - input of the model, accepts a string.
+Output tensor0:
+  (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64)
+  - output scores for each class, if type is one of the Int types,
+    dequantize it to double
+Output tensor1: optional
+  (kTfLiteString)
+  - output classname for each class, should be of the same length with
+    scores. If this tensor is not present, the API uses score indices as
+    classnames.
+By default the API tries to find the input/output tensors with default
+configurations in NLClassifierOptions, with tensor name prioritized over
+tensor index. The option is configurable for different TFLite models.
+
+Use the C++ API to perform language ID classification as follows:
+
+```cc
+using tflite::task::text::nlclassifier::NLClassifier;
+using tflite::task::core::Category;
+auto classifier = NLClassifier::CreateFromFileAndOptions("/path/to/model");
+// Or create a customized NLClassifierOptions
+// NLClassifierOptions options =
+//   {
+//     .output_score_tensor_name = myOutputScoreTensorName,
+//     .output_label_tensor_name = myOutputLabelTensorName,
+//   }
+// auto classifier = NLClassifier::CreateFromFileAndOptions("/path/to/model", options);
+std::string context = "What language is this?";
+std::vector<Category> categories = classifier->Classify(context);
+// Access category results.
+for (const Category& category : categories) {
+  std::cout << absl::StrFormat("Language: %s Probability: %f", category.class_name, category.score)
+            << std::endl;
+}
+// Output:
+// Language: en Probability=0.9
+// ... (and more)
+//
+// So the top-1 answer is 'en'.
+``` + +## Vision Task Libraries + +### Image Classifier + +`ImageClassifier` accepts any TFLite image classification model (with optional, +but strongly recommended, TFLite Model Metadata) that conforms to the following +spec: + +Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`): + + - image input of size `[batch x height x width x channels]`. + - batch inference is not supported (`batch` is required to be 1). + - only RGB inputs are supported (`channels` is required to be 3). + - if type is `kTfLiteFloat32`, `NormalizationOptions` are required to be + attached to the metadata for input normalization. + +At least one output tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`) with: + + - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or + `[1 x 1 x 1 x N]` + - optional (but recommended) label map(s) as AssociatedFile-s with type + TENSOR_AXIS_LABELS, containing one label per line. The first such + AssociatedFile (if any) is used to fill the `class_name` field of the + results. The `display_name` field is filled from the AssociatedFile (if + any) whose locale matches the `display_names_locale` field of the + `ImageClassifierOptions` used at creation time ("en" by default, i.e. + English). If none of these are available, only the `index` field of the + results will be filled. + +An example of such model can be found at: +https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1 + +Example usage: + +```cc +// More options are available (e.g. max number of results to return). At the +// very least, the model must be specified: +ImageClassifierOptions options; +options.mutable_model_file_with_metadata()->set_file_name( + "/path/to/model.tflite"); + +// Create an ImageClassifier instance from the options. +StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options); +// Check if an error occurred. 
+if (!image_classifier_or.ok()) { + std::cerr << "An error occurred during ImageClassifier creation: " + << image_classifier_or.status().message(); + return; +} +std::unique_ptr<ImageClassifier> image_classifier = + std::move(image_classifier_or.value()); + +// Prepare FrameBuffer input from e.g. image RGBA data, width and height: +std::unique_ptr<FrameBuffer> frame_buffer = + CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height}); + +// Run inference: +StatusOr<ClassificationResult> result_or = + image_classifier->Classify(*frame_buffer); +// Check if an error occurred. +if (!result_or.ok()) { + std::cerr << "An error occurred during classification: " + << result_or.status().message(); + return; +} +ClassificationResult result = result_or.value(); + +// Example value for 'result': +// +// classifications { +// classes { index: 934 score: 0.95 class_name: "cat" } +// classes { index: 948 score: 0.007 class_name: "dog" } +// classes { index: 927 score: 0.003 class_name: "fox" } +// head_index: 0 +// } +``` + +A CLI demo tool is also available [here][1] for easily trying out this API. + +### Object Detector + +`ObjectDetector` accepts any object detection TFLite model (with mandatory +TFLite Model Metadata) that conforms to the following spec (e.g. Single Shot +Detectors): + +Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`): + + - image input of size `[batch x height x width x channels]`. + - batch inference is not supported (`batch` is required to be 1). + - only RGB inputs are supported (`channels` is required to be 3). + - if type is kTfLiteFloat32, `NormalizationOptions` are required to be + attached to the metadata for input normalization. + +Output tensors must be the 4 outputs (type: `kTfLiteFloat32`) of a +[`DetectionPostProcess`][2] op, i.e: + +* Locations: + + - of size `[num_results x 4]`, the inner array + representing bounding boxes in the form [top, left, right, bottom]. 
+ - BoundingBoxProperties are required to be attached to the metadata + and must specify type=BOUNDARIES and coordinate_type=RATIO. + +* Classes: + + - of size `[num_results]`, each value representing the + integer index of a class. + - optional (but recommended) label map(s) can be attached as + AssociatedFile-s with type TENSOR_VALUE_LABELS, containing one label per + line. The first such AssociatedFile (if any) is used to fill the + `class_name` field of the results. The `display_name` field is filled + from the AssociatedFile (if any) whose locale matches the + `display_names_locale` field of the `ObjectDetectorOptions` used at + creation time ("en" by default, i.e. English). If none of these are + available, only the `index` field of the results will be filled. + +* Scores: + + - of size `[num_results]`, each value representing the score + of the detected object. + +* Number of results: + + - integer `num_results` as a tensor of size `[1]` + +An example of such model can be found at: +https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1 + +Example usage: + +```cc +// More options are available (e.g. max number of results to return). At the +// very least, the model must be specified: +ObjectDetectorOptions options; +options.mutable_model_file_with_metadata()->set_file_name( + "/path/to/model.tflite"); + +// Create an ObjectDetector instance from the options. +StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options); +// Check if an error occurred. +if (!object_detector_or.ok()) { + std::cerr << "An error occurred during ObjectDetector creation: " + << object_detector_or.status().message(); + return; +} +std::unique_ptr<ObjectDetector> object_detector = + std::move(object_detector_or.value()); + +// Prepare FrameBuffer input from e.g. 
image RGBA data, width and height: +std::unique_ptr<FrameBuffer> frame_buffer = + CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height}); + +// Run inference: +StatusOr<DetectionResult> result_or = object_detector->Detect(*frame_buffer); +// Check if an error occurred. +if (!result_or.ok()) { + std::cerr << "An error occurred during detection: " + << result_or.status().message(); + return; +} +DetectionResult result = result_or.value(); + +// Example value for 'result': +// +// detections { +// bounding_box { +// origin_x: 54 +// origin_y: 398 +// width: 393 +// height: 196 +// } +// classes { index: 16 score: 0.65 class_name: "cat" } +// } +// detections { +// bounding_box { +// origin_x: 602 +// origin_y: 157 +// width: 394 +// height: 447 +// } +// classes { index: 17 score: 0.45 class_name: "dog" } +// } +``` + +A CLI demo tool is available [here][3] for easily trying out this API. + +### Image Segmenter + +`ImageSegmenter` accepts any TFLite model (with optional, but strongly +recommended, TFLite Model Metadata) that conforms to the following spec: + +Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`): + + - image input of size `[batch x height x width x channels]`. + - batch inference is not supported (`batch` is required to be 1). + - only RGB inputs are supported (`channels` is required to be 3). + - if type is kTfLiteFloat32, `NormalizationOptions` are required to be + attached to the metadata for input normalization. + +Output tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`): + + - tensor of size `[batch x mask_height x mask_width x num_classes]`, where + `batch` is required to be 1, `mask_width` and `mask_height` are the + dimensions of the segmentation masks produced by the model, and + `num_classes` is the number of classes supported by the model. + - optional (but recommended) label map(s) can be attached as + AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per + line. 
The first such AssociatedFile (if any) is used to fill the
+      `class_name` field of the results. The `display_name` field is filled
+      from the AssociatedFile (if any) whose locale matches the
+      `display_names_locale` field of the `ImageSegmenterOptions` used at
+      creation time ("en" by default, i.e. English). If none of these are
+      available, only the `index` field of the results will be filled.
+
+An example of such model can be found at:
+https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1
+
+Example usage:
+
+```cc
+// More options are available to select between returning a single category
+// mask or multiple confidence masks during post-processing.
+ImageSegmenterOptions options;
+options.mutable_model_file_with_metadata()->set_file_name(
+    "/path/to/model.tflite");
+
+// Create an ImageSegmenter instance from the options.
+StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
+    ImageSegmenter::CreateFromOptions(options);
+// Check if an error occurred.
+if (!image_segmenter_or.ok()) {
+  std::cerr << "An error occurred during ImageSegmenter creation: "
+            << image_segmenter_or.status().message();
+  return;
+}
+std::unique_ptr<ImageSegmenter> image_segmenter =
+    std::move(image_segmenter_or.value());
+
+// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
+std::unique_ptr<FrameBuffer> frame_buffer =
+    CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});
+
+// Run inference:
+StatusOr<SegmentationResult> result_or =
+    image_segmenter->Segment(*frame_buffer);
+// Check if an error occurred.
+if (!result_or.ok()) {
+  std::cerr << "An error occurred during segmentation: "
+            << result_or.status().message();
+  return;
+}
+SegmentationResult result = result_or.value();
+
+// Example value for 'result':
+//
+// segmentation {
+//   width: 257
+//   height: 257
+//   category_mask: "\x00\x01..."
+// colored_labels { r: 0 g: 0 b: 0 class_name: "background" } +// colored_labels { r: 128 g: 0 b: 0 class_name: "aeroplane" } +// ... +// colored_labels { r: 128 g: 192 b: 0 class_name: "train" } +// colored_labels { r: 0 g: 64 b: 128 class_name: "tv" } +// } +// +// Where 'category_mask' is a byte buffer of size 'width' x 'height', with the +// value of each pixel representing the class this pixel belongs to (e.g. '\x00' +// means "background", '\x01' means "aeroplane", etc). +// 'colored_labels' provides the label for each possible value, as well as +// suggested RGB components to optionally transform the result into a more +// human-friendly colored image. +// +``` + +A CLI demo tool is available [here][4] for easily trying out this API. + +[1]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc +[2]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc +[3]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc +[4]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc diff --git a/tensorflow_lite_support/cc/task/core/BUILD b/tensorflow_lite_support/cc/task/core/BUILD new file mode 100644 index 00000000..1995dfe3 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/BUILD @@ -0,0 +1,156 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tflite_engine", + srcs = ["tflite_engine.cc"], + hdrs = ["tflite_engine.h"], + deps = [ + ":external_file_handler", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + # The dependency on 
builtin_ops here is only for the default + # value of the OpResolver parameter: + # std::unique_ptr<tflite::IterableOpResolver> resolver = + # absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>() + # When linking statically, if the client of this library doesn't use + # the default argument, this dependency does not cause all the builtin ops + # to get included in the executable. + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/tools:verifier", + ] + select({ + "//tensorflow_lite_support/cc:tflite_use_c_api": [ + "@org_tensorflow//tensorflow/lite/core/api:verifier", + "@org_tensorflow//tensorflow/lite/c:c_api", + "@org_tensorflow//tensorflow/lite/c:c_api_experimental", + "@org_tensorflow//tensorflow/lite:kernel_api", + "@org_tensorflow//tensorflow/lite:stderr_reporter", + ], + "//conditions:default": [ + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:kernel_api", + ], + }) + [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:tflite_wrapper", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + ], +) + +# This is a duplicate of the above 'tflite_engine' target that is used for +# testing with TFLITE_USE_C_API defined. It should be the same as the target +# above, except that it adds +# testonly = 1, +# defines = ["TFLITE_USE_C_API"], +# and that it resolves the conditional deps from the 'select' as if +# "//tensorflow_lite_support/cc:tflite_use_c_api" was enabled. +# This allows testing the TFLITE_USE_C_API case even when +# '--copt=-DTFLITE_USE_C_API' wasn't passed on the build command line. 
+cc_library( + name = "tflite_engine_with_c_api_for_test", + testonly = 1, + srcs = ["tflite_engine.cc"], + hdrs = ["tflite_engine.h"], + defines = ["TFLITE_USE_C_API"], + deps = [ + ":external_file_handler", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/tools:verifier", + ] + [ + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + "@org_tensorflow//tensorflow/lite/core/api:verifier", + "@org_tensorflow//tensorflow/lite/c:c_api", + "@org_tensorflow//tensorflow/lite/c:c_api_experimental", + "@org_tensorflow//tensorflow/lite:kernel_api", + "@org_tensorflow//tensorflow/lite:stderr_reporter", + ] + [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:tflite_wrapper_with_c_api_for_test", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + ], +) + +cc_library( + name = "base_task_api", + hdrs = ["base_task_api.h"], + deps = [ + ":tflite_engine", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/port:tflite_wrapper", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/c:common", + ], +) + +cc_library( + name = "task_api_factory", + hdrs = ["task_api_factory.h"], + deps = [ + ":base_task_api", + ":tflite_engine", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "@com_google_absl//absl/status", + "@org_tensorflow//tensorflow/lite/c:common", + 
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", + "@org_tensorflow//tensorflow/lite/kernels:op_macros", + ], +) + +cc_library( + name = "task_utils", + srcs = ["task_utils.cc"], + hdrs = ["task_utils.h"], + deps = [ + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite:type_to_tflitetype", + "@org_tensorflow//tensorflow/lite/kernels:op_macros", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_library( + name = "category", + hdrs = ["category.h"], +) + +cc_library( + name = "external_file_handler", + srcs = ["external_file_handler.cc"], + hdrs = ["external_file_handler.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) diff --git a/tensorflow_lite_support/cc/task/core/base_task_api.h b/tensorflow_lite_support/cc/task/core/base_task_api.h new file mode 100644 index 00000000..a27f785b --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/base_task_api.h @@ -0,0 +1,144 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_ + +#include <utility> + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/port/tflite_wrapper.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +namespace tflite { +namespace task { +namespace core { + +class BaseUntypedTaskApi { + public: + explicit BaseUntypedTaskApi(std::unique_ptr<TfLiteEngine> engine) + : engine_{std::move(engine)} {} + + virtual ~BaseUntypedTaskApi() = default; + + TfLiteEngine* GetTfLiteEngine() { return engine_.get(); } + const TfLiteEngine* GetTfLiteEngine() const { return engine_.get(); } + + const metadata::ModelMetadataExtractor* GetMetadataExtractor() const { + return engine_->metadata_extractor(); + } + + protected: + std::unique_ptr<TfLiteEngine> engine_; +}; + +template <class OutputType, class... InputTypes> +class BaseTaskApi : public BaseUntypedTaskApi { + public: + explicit BaseTaskApi(std::unique_ptr<TfLiteEngine> engine) + : BaseUntypedTaskApi(std::move(engine)) {} + // BaseTaskApi is neither copyable nor movable. 
+ BaseTaskApi(const BaseTaskApi&) = delete; + BaseTaskApi& operator=(const BaseTaskApi&) = delete; + + // Cancels the current running TFLite invocation on CPU. + // + // Usually called on a different thread than the one inference is running on. + // Calling Cancel() will cause the underlying TFLite interpreter to return an + // error, which will turn into a `CANCELLED` status and empty results. Calling + // Cancel() at the other time will not take any effect on the current or + // following invocation. It is perfectly fine to run inference again on the + // same instance after a cancelled invocation. If the TFLite inference is + // partially delegated on CPU, logs a warning message and only cancels the + // invocation running on CPU. Other invocation which depends on the output of + // the CPU invocation will not be executed. + void Cancel() { engine_->Cancel(); } + + protected: + // Subclasses need to populate input_tensors from api_inputs. + virtual absl::Status Preprocess( + const std::vector<TfLiteTensor*>& input_tensors, + InputTypes... api_inputs) = 0; + + // Subclasses need to construct OutputType object from output_tensors. + // Original inputs are also provided as they may be needed. + virtual tflite::support::StatusOr<OutputType> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + InputTypes... api_inputs) = 0; + + // Returns (the addresses of) the model's inputs. + std::vector<TfLiteTensor*> GetInputTensors() { return engine_->GetInputs(); } + + // Returns (the addresses of) the model's outputs. + std::vector<const TfLiteTensor*> GetOutputTensors() { + return engine_->GetOutputs(); + } + + // Performs inference using tflite::support::TfLiteInterpreterWrapper + // InvokeWithoutFallback(). + tflite::support::StatusOr<OutputType> Infer(InputTypes... 
args) { + tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper = + engine_->interpreter_wrapper(); + // Note: AllocateTensors() is already performed by the interpreter wrapper + // at InitInterpreter time (see TfLiteEngine). + RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...)); + absl::Status status = interpreter_wrapper->InvokeWithoutFallback(); + if (!status.ok()) { + return status.GetPayload(tflite::support::kTfLiteSupportPayload) + .has_value() + ? status + : tflite::support::CreateStatusWithPayload(status.code(), + status.message()); + } + return Postprocess(GetOutputTensors(), args...); + } + + // Performs inference using tflite::support::TfLiteInterpreterWrapper + // InvokeWithFallback() to benefit from automatic fallback from delegation to + // CPU where applicable. + tflite::support::StatusOr<OutputType> InferWithFallback(InputTypes... args) { + tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper = + engine_->interpreter_wrapper(); + // Note: AllocateTensors() is already performed by the interpreter wrapper + // at InitInterpreter time (see TfLiteEngine). + RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...)); + auto set_inputs_nop = + [](tflite::task::core::TfLiteEngine::Interpreter* interpreter) + -> absl::Status { + // NOP since inputs are populated at Preprocess() time. + return absl::OkStatus(); + }; + absl::Status status = + interpreter_wrapper->InvokeWithFallback(set_inputs_nop); + if (!status.ok()) { + return status.GetPayload(tflite::support::kTfLiteSupportPayload) + .has_value() + ? 
status + : tflite::support::CreateStatusWithPayload(status.code(), + status.message()); + } + return Postprocess(GetOutputTensors(), args...); + } +}; + +} // namespace core +} // namespace task +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_ diff --git a/tensorflow_lite_support/cc/task/core/category.h b/tensorflow_lite_support/cc/task/core/category.h new file mode 100644 index 00000000..a99f994c --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/category.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_ +#include <string> + +namespace tflite { +namespace task { +namespace core { + +// Result for classification APIs. 
+struct Category { + std::string class_name; + double score; + Category(const std::string& class_name, double score) + : class_name(class_name), score(score) {} + + friend bool operator==(const Category& lhs, const Category& rhs) { + return lhs.score == rhs.score && lhs.class_name == rhs.class_name; + } + + friend bool operator!=(const Category& lhs, const Category& rhs) { + return !(lhs == rhs); + } +}; + +} // namespace core +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_CATEGORY_H_ diff --git a/tensorflow_lite_support/cc/task/core/external_file_handler.cc b/tensorflow_lite_support/cc/task/core/external_file_handler.cc new file mode 100644 index 00000000..e2150c13 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/external_file_handler.cc @@ -0,0 +1,194 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" + +#include <errno.h> +#include <fcntl.h> +#include <stddef.h> +#include <sys/mman.h> +#include <unistd.h> + +#include <memory> +#include <string> + +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" + +namespace tflite { +namespace task { +namespace core { +namespace { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +// Gets the offset aligned to page size for mapping given files into memory by +// file descriptor correctly, as according to mmap(2), the offset used in mmap +// must be a multiple of sysconf(_SC_PAGE_SIZE). +int64 GetPageSizeAlignedOffset(int64 offset) { + int64 aligned_offset = offset; + int64 page_size = sysconf(_SC_PAGE_SIZE); + if (offset % page_size != 0) { + aligned_offset = offset / page_size * page_size; + } + return aligned_offset; +} + +} // namespace + +/* static */ +StatusOr<std::unique_ptr<ExternalFileHandler>> +ExternalFileHandler::CreateFromExternalFile(const ExternalFile* external_file) { + // Use absl::WrapUnique() to call private constructor: + // https://abseil.io/tips/126. 
+ std::unique_ptr<ExternalFileHandler> handler = + absl::WrapUnique(new ExternalFileHandler(external_file)); + + RETURN_IF_ERROR(handler->MapExternalFile()); + + return handler; +} + +absl::Status ExternalFileHandler::MapExternalFile() { + if (!external_file_.file_content().empty()) { + return absl::OkStatus(); + } + if (external_file_.file_name().empty() && + !external_file_.has_file_descriptor_meta()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "ExternalFile must specify at least one of 'file_content', file_name' " + "or 'file_descriptor_meta'.", + TfLiteSupportStatus::kInvalidArgumentError); + } + // Obtain file descriptor, offset and size. + int fd = -1; + if (!external_file_.file_name().empty()) { + owned_fd_ = open(external_file_.file_name().c_str(), O_RDONLY); + if (owned_fd_ < 0) { + const std::string error_message = absl::StrFormat( + "Unable to open file at %s", external_file_.file_name()); + switch (errno) { + case ENOENT: + return CreateStatusWithPayload( + StatusCode::kNotFound, error_message, + TfLiteSupportStatus::kFileNotFoundError); + case EACCES: + case EPERM: + return CreateStatusWithPayload( + StatusCode::kPermissionDenied, error_message, + TfLiteSupportStatus::kFilePermissionDeniedError); + case EINTR: + return CreateStatusWithPayload(StatusCode::kUnavailable, + error_message, + TfLiteSupportStatus::kFileReadError); + case EBADF: + return CreateStatusWithPayload(StatusCode::kFailedPrecondition, + error_message, + TfLiteSupportStatus::kFileReadError); + default: + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrFormat("%s, errno=%d", error_message, errno), + TfLiteSupportStatus::kFileReadError); + } + } + fd = owned_fd_; + } else { + fd = external_file_.file_descriptor_meta().fd(); + if (fd < 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Provided file descriptor is invalid: %d < 0", fd), + TfLiteSupportStatus::kInvalidArgumentError); + } + buffer_offset_ = 
external_file_.file_descriptor_meta().offset(); + buffer_size_ = external_file_.file_descriptor_meta().length(); + } + // Get actual file size. Always use 0 as offset to lseek(2) to get the actual + // file size, as SEEK_END returns the size of the file *plus* offset. + size_t file_size = lseek(fd, /*offset=*/0, SEEK_END); + if (file_size <= 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrFormat("Unable to get file size, errno=%d", errno), + TfLiteSupportStatus::kFileReadError); + } + // Deduce buffer size if not explicitly provided through file descriptor. + if (buffer_size_ <= 0) { + buffer_size_ = file_size - buffer_offset_; + } + // Check for out of range issues. + if (file_size <= buffer_offset_) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Provided file offset (%d) exceeds or matches actual " + "file length (%d)", + buffer_offset_, file_size), + TfLiteSupportStatus::kInvalidArgumentError); + } + if (file_size < buffer_size_ + buffer_offset_) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Provided file length + offset (%d) exceeds actual " + "file length (%d)", + buffer_size_ + buffer_offset_, file_size), + TfLiteSupportStatus::kInvalidArgumentError); + } + // If buffer_offset_ is not multiple of sysconf(_SC_PAGE_SIZE), align with + // extra leading bytes and adjust buffer_size_ to account for the extra + // leading bytes. + buffer_aligned_offset_ = GetPageSizeAlignedOffset(buffer_offset_); + buffer_aligned_size_ = buffer_size_ + buffer_offset_ - buffer_aligned_offset_; + // Map into memory. 
+ buffer_ = mmap(/*addr=*/nullptr, buffer_aligned_size_, PROT_READ, MAP_SHARED, + fd, buffer_aligned_offset_); + if (buffer_ == MAP_FAILED) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrFormat("Unable to map file to memory buffer, errno=%d", errno), + TfLiteSupportStatus::kFileMmapError); + } + return absl::OkStatus(); +} + +absl::string_view ExternalFileHandler::GetFileContent() { + if (!external_file_.file_content().empty()) { + return external_file_.file_content(); + } else { + return absl::string_view(static_cast<const char*>(buffer_) + + buffer_offset_ - buffer_aligned_offset_, + buffer_size_); + } +} + +ExternalFileHandler::~ExternalFileHandler() { + if (buffer_ != MAP_FAILED) { + munmap(buffer_, buffer_aligned_size_); + } + if (owned_fd_ >= 0) { + close(owned_fd_); + } +} + +} // namespace core +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/core/external_file_handler.h b/tensorflow_lite_support/cc/task/core/external_file_handler.h new file mode 100644 index 00000000..236d9034 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/external_file_handler.h @@ -0,0 +1,94 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_EXTERNAL_FILE_HANDLER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_EXTERNAL_FILE_HANDLER_H_ + +#include <memory> + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" + +namespace tflite { +namespace task { +namespace core { + +// Handler providing easy access to the contents of a file specified by an +// ExternalFile proto [1]. Takes care (if needed, depending on the provided +// proto fields) of opening and/or mapping the file in memory at creation time, +// as well as closing and/or unmapping at destruction time. +// +// [1]: support/c/task/core/proto/external_file.proto +class ExternalFileHandler { + public: + // Creates an ExternalFileHandler from the input ExternalFile proto and + // returns a pointer to the new object. Ownership is transferred to the + // caller. Returns an error if the creation failed, which may happen if the + // provided ExternalFile can't be opened or mapped into memory. + // + // Warning: Does not take ownership of `external_file`, which must refer to a + // valid proto that outlives this object. + static tflite::support::StatusOr<std::unique_ptr<ExternalFileHandler>> + CreateFromExternalFile(const ExternalFile* external_file); + + ~ExternalFileHandler(); + + // Returns the content of the ExternalFile as a string_view guaranteed to be + // valid as long as the ExternalFileHandler is alive. + absl::string_view GetFileContent(); + + private: + // Private constructor, called from CreateFromExternalFile(). 
+ explicit ExternalFileHandler(const ExternalFile* external_file) + : external_file_(*external_file) {} + + // Opens (if provided by path) and maps (if provided by path or file + // descriptor) the external file in memory. Does nothing otherwise, as file + // contents are already loaded in memory. + absl::Status MapExternalFile(); + + // Reference to the input ExternalFile. + const ExternalFile& external_file_; + + // The file descriptor of the ExternalFile if provided by path, as it is + // opened and owned by this class. Set to -1 otherwise. + int owned_fd_{-1}; + + // Points to the memory buffer mapped from the file descriptor of the + // ExternalFile, if provided by path or file descriptor. + void* buffer_{}; + + // The mapped memory buffer offset, if any. + int64 buffer_offset_{}; + // The size in bytes of the mapped memory buffer, if any. + int64 buffer_size_{}; + + // As mmap(2) requires the offset to be a multiple of sysconf(_SC_PAGE_SIZE): + + // The aligned mapped memory buffer offset, if any. + int64 buffer_aligned_offset_{}; + // The aligned mapped memory buffer size in bytes taking into account the + // offset shift introduced by buffer_aligned_offset_, if any. 
+ int64 buffer_aligned_size_{}; +}; + +} // namespace core +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_EXTERNAL_FILE_HANDLER_H_ diff --git a/tensorflow_lite_support/cc/task/core/proto/BUILD b/tensorflow_lite_support/cc/task/core/proto/BUILD new file mode 100644 index 00000000..7418e5b2 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/proto/BUILD @@ -0,0 +1,27 @@ +load("//tensorflow_lite_support/cc/port:build_defs.bzl", "support_cc_proto_library") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +proto_library( + name = "external_file_proto", + srcs = ["external_file.proto"], +) + +support_cc_proto_library( + name = "external_file_cc_proto", + srcs = ["external_file.proto"], + deps = [ + ":external_file_proto", + ], +) + +cc_library( + name = "external_file_proto_inc", + hdrs = ["external_file_proto_inc.h"], + deps = [":external_file_cc_proto"], +) diff --git a/tensorflow_lite_support/cc/task/core/proto/external_file.proto b/tensorflow_lite_support/cc/task/core/proto/external_file.proto new file mode 100644 index 00000000..c0a42124 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/proto/external_file.proto @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.core; + + +// Represents external files used by the Task APIs (e.g. TF Lite FlatBuffer or +// plain-text labels file). The files can be specified by one of the following +// three ways: +// +// (1) file contents loaded in `file_content`. +// (2) file path in `file_name`. +// (3) file descriptor through `file_descriptor_meta` as returned by open(2). +// +// If more than one field of these fields is provided, they are used in this +// precedence order. +// Next id: 5 +message ExternalFile { + // The path to the file to open and mmap in memory + optional string file_name = 1; + + // The file contents as a byte array. + optional bytes file_content = 2; + + // The file descriptor to a file opened with open(2), with optional additional + // offset and length information. + optional FileDescriptorMeta file_descriptor_meta = 4; + + // Deprecated field numbers. + reserved 3; +} + +// A proto defining file descriptor metadata for mapping file into memory using +// mmap(2). +message FileDescriptorMeta { + // File descriptor as returned by open(2). + optional int32 fd = 1; + + // Optional length of the mapped memory. If not specified, the actual file + // size is used at runtime. + // + // This is an advanced option, e.g. this can be used on Android to specify the + // length of a given asset obtained from AssetFileDescriptor#getLength(). + optional int64 length = 2; + + // Optional starting offset in the file referred to by the file descriptor + // `fd`. + // + // This is an advanced option, e.g. this can be used on Android to specify the + // offset of a given asset obtained from AssetFileDescriptor#getStartOffset(). 
+ optional int64 offset = 3; +} + diff --git a/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h b/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h new file mode 100644 index 00000000..017aa651 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_EXTERNAL_FILE_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_EXTERNAL_FILE_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/external_file.pb.h" +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_PROTO_EXTERNAL_FILE_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt b/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt new file mode 100644 index 00000000..dafb0fde --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/proto/proto_config.pbtxt @@ -0,0 +1,2 @@ +allow_all: true +optimize_mode: LITE_RUNTIME diff --git a/tensorflow_lite_support/cc/task/core/task_api_factory.h b/tensorflow_lite_support/cc/task/core/task_api_factory.h new file mode 100644 index 00000000..06c3a012 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/task_api_factory.h @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_API_FACTORY_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_API_FACTORY_H_ + +#include <memory> + +#include "absl/status/status.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +namespace tflite { +namespace task { +namespace core { +template <typename T> +using EnableIfBaseUntypedTaskApiSubclass = typename std::enable_if< + std::is_base_of<BaseUntypedTaskApi, T>::value>::type*; + +// Template creator for all subclasses of BaseTaskApi +class TaskAPIFactory { + public: + TaskAPIFactory() = delete; + + template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromBuffer( + const char* buffer_data, size_t buffer_size, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + int num_threads = 1) { + auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver)); + 
RETURN_IF_ERROR(engine->BuildModelFromFlatBuffer(buffer_data, buffer_size)); + return CreateFromTfLiteEngine<T>(std::move(engine), num_threads); + } + + template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromFile( + const string& file_name, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + int num_threads = 1) { + auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver)); + RETURN_IF_ERROR(engine->BuildModelFromFile(file_name)); + return CreateFromTfLiteEngine<T>(std::move(engine), num_threads); + } + + template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromFileDescriptor( + int file_descriptor, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + int num_threads = 1) { + auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver)); + RETURN_IF_ERROR(engine->BuildModelFromFileDescriptor(file_descriptor)); + return CreateFromTfLiteEngine<T>(std::move(engine), num_threads); + } + + template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> + CreateFromExternalFileProto( + const ExternalFile* external_file, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + int num_threads = 1) { + auto engine = absl::make_unique<TfLiteEngine>(std::move(resolver)); + RETURN_IF_ERROR(engine->BuildModelFromExternalFileProto(external_file)); + return CreateFromTfLiteEngine<T>(std::move(engine), num_threads); + } + + private: + template <typename T, EnableIfBaseUntypedTaskApiSubclass<T> = nullptr> + static tflite::support::StatusOr<std::unique_ptr<T>> CreateFromTfLiteEngine( + std::unique_ptr<TfLiteEngine> engine, int num_threads) { + 
RETURN_IF_ERROR(engine->InitInterpreter(num_threads)); + return absl::make_unique<T>(std::move(engine)); + } +}; + +} // namespace core +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_API_FACTORY_H_ diff --git a/tensorflow_lite_support/cc/task/core/task_utils.cc b/tensorflow_lite_support/cc/task/core/task_utils.cc new file mode 100644 index 00000000..de733ae9 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/task_utils.cc @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/core/task_utils.h" + +#include <fstream> + +#include "absl/strings/str_cat.h" + +namespace tflite { +namespace task { +namespace core { + +double Dequantize(const TfLiteTensor& tensor, int index) { + int32_t quantized_value = 0; + switch (tensor.type) { + case kTfLiteUInt8: + quantized_value = GetTensorData<uint8_t>(&tensor)[index]; + break; + case kTfLiteInt8: + quantized_value = GetTensorData<int8_t>(&tensor)[index]; + break; + case kTfLiteInt16: + quantized_value = GetTensorData<int16_t>(&tensor)[index]; + break; + default: + TF_LITE_FATAL( + absl::StrCat( + "Invalid tensor type for dequantization ", tensor.name, + ". 
Requested kTfLiteUInt8, kTfLiteInt8 or kTfLiteInt16, got ", + TfLiteTypeGetName(tensor.type), ".") + .c_str()); + } + return tensor.params.scale * (quantized_value - tensor.params.zero_point); +} + +std::string GetStringAtIndex(const TfLiteTensor* labels, int index) { + const auto& strref = tflite::GetString(labels, index); + return std::string(strref.str, strref.len); +} + +std::string LoadBinaryContent(const char* filename) { + std::ifstream input_file(filename, std::ios::binary | std::ios::ate); + // Find buffer size from input file, and load the buffer. + size_t buffer_size = input_file.tellg(); + std::string buffer(buffer_size, '\0'); + input_file.seekg(0, std::ios::beg); + input_file.read(const_cast<char*>(buffer.c_str()), buffer_size); + return buffer; +} + +} // namespace core +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/core/task_utils.h b/tensorflow_lite_support/cc/task/core/task_utils.h new file mode 100644 index 00000000..c1c3fc31 --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/task_utils.h @@ -0,0 +1,182 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_ + +#include <algorithm> +#include <cstring> +#include <numeric> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/type_to_tflitetype.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace core { + +// Checks if data type of tensor is T and returns the pointer casted to T if +// applicable, returns nullptr if tensor type is not T. +// See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType. +template <typename T> +T* TypedTensor(const TfLiteTensor* tensor_ptr) { + if (tensor_ptr->type == typeToTfLiteType<T>()) { + return reinterpret_cast<T*>(tensor_ptr->data.raw); + } + return nullptr; +} + +// Checks and returns type of a tensor, fails if tensor type is not T. +template <typename T> +T* AssertAndReturnTypedTensor(const TfLiteTensor* tensor) { + if (T* v = TypedTensor<T>(tensor)) return v; + // TODO(b/150903834): throw exceptions instead + TF_LITE_ASSERT(tensor->data.raw); + TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name, + ". Requested ", + TfLiteTypeGetName(typeToTfLiteType<T>()), ", got ", + TfLiteTypeGetName(tensor->type), ".") + .c_str()); +} + +// Populates tensor with array of data, fails if data type doesn't match tensor +// type or has not the same number of elements. 
+template <typename T> +inline void PopulateTensor(const T* data, int num_elements, + TfLiteTensor* tensor) { + T* v = AssertAndReturnTypedTensor<T>(tensor); + size_t bytes = num_elements * sizeof(T); + // TODO(b/150903834): throw exceptions instead + TF_LITE_ASSERT(tensor->bytes == bytes); + memcpy(v, data, bytes); +} + +// Populates tensor with vector of data, fails if data type doesn't match tensor +// type or has not the same number of elements. +template <typename T> +inline void PopulateTensor(const std::vector<T>& data, TfLiteTensor* tensor) { + return PopulateTensor<T>(data.data(), data.size(), tensor); +} + +template <> +inline void PopulateTensor<std::string>(const std::vector<std::string>& data, + TfLiteTensor* tensor) { + if (tensor->type != kTfLiteString) { + TF_LITE_FATAL(absl::StrCat("Type mismatch for tensor ", tensor->name, + ". Requested STRING, got ", + TfLiteTypeGetName(tensor->type), ".") + .c_str()); + } + tflite::DynamicBuffer input_buf; + for (const auto& value : data) { + input_buf.AddString(value.data(), value.length()); + } + input_buf.WriteToTensorAsVector(tensor); +} + +// Populates tensor one data item, fails if data type doesn't match tensor +// type. +template <typename T> +inline void PopulateTensor(const T& data, TfLiteTensor* tensor) { + T* v = AssertAndReturnTypedTensor<T>(tensor); + *v = data; +} + +template <> +inline void PopulateTensor<std::string>(const std::string& data, + TfLiteTensor* tensor) { + tflite::DynamicBuffer input_buf; + input_buf.AddString(data.data(), data.length()); + input_buf.WriteToTensorAsVector(tensor); +} + +// Populates a vector from the tensor, fails if data type doesn't match tensor +// type. 
+template <typename T> +inline void PopulateVector(const TfLiteTensor* tensor, std::vector<T>* data) { + AssertAndReturnTypedTensor<T>(tensor); + const T* results = GetTensorData<T>(tensor); + size_t num = tensor->bytes / sizeof(tensor->type); + data->reserve(num); + for (int i = 0; i < num; i++) { + data->emplace_back(results[i]); + } +} + +template <> +inline void PopulateVector<std::string>(const TfLiteTensor* tensor, + std::vector<std::string>* data) { + AssertAndReturnTypedTensor<std::string>(tensor); + int num = GetStringCount(tensor); + data->reserve(num); + for (int i = 0; i < num; i++) { + const auto& strref = tflite::GetString(tensor, i); + data->emplace_back(strref.str, strref.len); + } +} + +// Returns the reversely sorted indices of a vector. +template <typename T> +std::vector<size_t> ReverseSortIndices(const std::vector<T>& v) { + std::vector<size_t> idx(v.size()); + std::iota(idx.begin(), idx.end(), 0); + + std::stable_sort(idx.begin(), idx.end(), + [&v](size_t i1, size_t i2) { return v[i2] < v[i1]; }); + + return idx; +} + +// Returns the original (dequantized) value of the 'index'-th element of +// 'tensor. +double Dequantize(const TfLiteTensor& tensor, int index); + +// Returns the index-th string from the tensor. +std::string GetStringAtIndex(const TfLiteTensor* labels, int index); + +// Loads binary content of a file into a string. +std::string LoadBinaryContent(const char* filename); + +// Gets the tensor from a vector of tensors with name specified inside metadata. 
+template <typename TensorType> +static TensorType* FindTensorByName( + const std::vector<TensorType*>& tensors, + const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>* + tensor_metadatas, + const std::string& name) { + if (tensor_metadatas == nullptr || + tensor_metadatas->size() != tensors.size()) { + return nullptr; + } + for (int i = 0; i < tensor_metadatas->size(); i++) { + if (strcmp(name.data(), tensor_metadatas->Get(i)->name()->c_str()) == 0) { + return tensors[i]; + } + } + return nullptr; +} + +} // namespace core +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TASK_UTILS_H_ diff --git a/tensorflow_lite_support/cc/task/core/tflite_engine.cc b/tensorflow_lite_support/cc/task/core/tflite_engine.cc new file mode 100644 index 00000000..cf923f6a --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/tflite_engine.cc @@ -0,0 +1,297 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +#include <unistd.h> + +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/builtin_ops.h" +#include "tensorflow/lite/stderr_reporter.h" +#include "tensorflow/lite/tools/verifier.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" + +#if TFLITE_USE_C_API +#include "tensorflow/lite/c/c_api_experimental.h" +#else +#include "tensorflow/lite/kernels/register.h" +#endif + +namespace tflite { +namespace task { +namespace core { + +#ifdef __ANDROID__ +// https://github.com/opencv/opencv/issues/14906 +// "ios_base::Init" object is not a part of Android's "iostream" header (in case +// of clang toolchain, NDK 20). +// +// Ref1: +// https://en.cppreference.com/w/cpp/io/ios_base/Init +// The header <iostream> behaves as if it defines (directly or indirectly) +// an instance of std::ios_base::Init with static storage duration +// +// Ref2: +// https://github.com/gcc-mirror/gcc/blob/gcc-8-branch/libstdc%2B%2B-v3/include/std/iostream#L73-L74 +static std::ios_base::Init s_iostream_initializer; +#endif + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::TfLiteSupportStatus; + +int TfLiteEngine::ErrorReporter::Report(const char* format, va_list args) { + return std::vsnprintf(error_message, sizeof(error_message), format, args); +} + +bool TfLiteEngine::Verifier::Verify(const char* data, int length, + tflite::ErrorReporter* reporter) { + return tflite::Verify(data, length, *op_resolver_, reporter); +} + +#if TFLITE_USE_C_API +TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver) + : model_(nullptr, TfLiteModelDelete), + resolver_(std::move(resolver)), + verifier_(resolver_.get()) {} +#else 
+TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver) + : model_(), resolver_(std::move(resolver)), verifier_(resolver_.get()) {} +#endif + +std::vector<TfLiteTensor*> TfLiteEngine::GetInputs() { + Interpreter* interpreter = this->interpreter(); + std::vector<TfLiteTensor*> tensors; + int input_count = InputCount(interpreter); + tensors.reserve(input_count); + for (int index = 0; index < input_count; index++) { + tensors.push_back(GetInput(interpreter, index)); + } + return tensors; +} + +std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() { + Interpreter* interpreter = this->interpreter(); + std::vector<const TfLiteTensor*> tensors; + int output_count = OutputCount(interpreter); + tensors.reserve(output_count); + for (int index = 0; index < output_count; index++) { + tensors.push_back(GetOutput(interpreter, index)); + } + return tensors; +} + +// The following function is adapted from the code in +// tflite::FlatBufferModel::VerifyAndBuildFromBuffer. +void TfLiteEngine::VerifyAndBuildModelFromBuffer(const char* buffer_data, + size_t buffer_size) { +#if TFLITE_USE_C_API + // First verify with the base flatbuffers verifier. + // This verifies that the model is a valid flatbuffer model. + flatbuffers::Verifier base_verifier( + reinterpret_cast<const uint8_t*>(buffer_data), buffer_size); + if (!VerifyModelBuffer(base_verifier)) { + TF_LITE_REPORT_ERROR(&error_reporter_, + "The model is not a valid Flatbuffer buffer"); + model_ = nullptr; + return; + } + // Next verify with the extra verifier. This verifies that the model only + // uses operators supported by the OpResolver. + if (!verifier_.Verify(buffer_data, buffer_size, &error_reporter_)) { + model_ = nullptr; + return; + } + // Build the model. 
+ model_.reset(TfLiteModelCreate(buffer_data, buffer_size)); +#else + model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( + buffer_data, buffer_size, &verifier_, &error_reporter_); +#endif +} + +absl::Status TfLiteEngine::InitializeFromModelFileHandler() { + const char* buffer_data = model_file_handler_->GetFileContent().data(); + size_t buffer_size = model_file_handler_->GetFileContent().size(); + VerifyAndBuildModelFromBuffer(buffer_data, buffer_size); + if (model_ == nullptr) { + // To be replaced with a proper switch-case when TF Lite model builder + // returns a `TfLiteStatus` code capturing this type of error. + if (absl::StrContains(error_reporter_.error_message, + "The model is not a valid Flatbuffer")) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, error_reporter_.error_message, + TfLiteSupportStatus::kInvalidFlatBufferError); + } else { + // TODO(b/154917059): augment status with another `TfLiteStatus` code when + // ready. And use a new `TfLiteStatus::kCoreTfLiteError` for the TFLS + // code, instead of the unspecified `kError`. 
+ return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrCat( + "Could not build model from the provided pre-loaded flatbuffer: ", + error_reporter_.error_message)); + } + } + + ASSIGN_OR_RETURN( + model_metadata_extractor_, + tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer( + buffer_data, buffer_size)); + + return absl::OkStatus(); +} + +absl::Status TfLiteEngine::BuildModelFromFlatBuffer(const char* buffer_data, + size_t buffer_size) { + if (model_) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Model already built"); + } + external_file_.set_file_content(std::string(buffer_data, buffer_size)); + ASSIGN_OR_RETURN( + model_file_handler_, + ExternalFileHandler::CreateFromExternalFile(&external_file_)); + return InitializeFromModelFileHandler(); +} + +absl::Status TfLiteEngine::BuildModelFromFile(const std::string& file_name) { + if (model_) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Model already built"); + } + external_file_.set_file_name(file_name); + ASSIGN_OR_RETURN( + model_file_handler_, + ExternalFileHandler::CreateFromExternalFile(&external_file_)); + return InitializeFromModelFileHandler(); +} + +absl::Status TfLiteEngine::BuildModelFromFileDescriptor(int file_descriptor) { + if (model_) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Model already built"); + } + external_file_.mutable_file_descriptor_meta()->set_fd(file_descriptor); + ASSIGN_OR_RETURN( + model_file_handler_, + ExternalFileHandler::CreateFromExternalFile(&external_file_)); + return InitializeFromModelFileHandler(); +} + +absl::Status TfLiteEngine::BuildModelFromExternalFileProto( + const ExternalFile* external_file) { + if (model_) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Model already built"); + } + ASSIGN_OR_RETURN(model_file_handler_, + ExternalFileHandler::CreateFromExternalFile(external_file)); + return InitializeFromModelFileHandler(); +} + +absl::Status TfLiteEngine::InitInterpreter(int 
num_threads) { + tflite::proto::ComputeSettings compute_settings; + return InitInterpreter(compute_settings, num_threads); +} + +#if TFLITE_USE_C_API +const TfLiteRegistration* FindBuiltinOp(void* user_data, + TfLiteBuiltinOperator builtin_op, + int version) { + OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data); + tflite::BuiltinOperator op = static_cast<tflite::BuiltinOperator>(builtin_op); + return op_resolver->FindOp(op, version); +} + +const TfLiteRegistration* FindCustomOp(void* user_data, const char* custom_op, + int version) { + OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data); + return op_resolver->FindOp(custom_op, version); +} +#endif + +absl::Status TfLiteEngine::InitInterpreter( + const tflite::proto::ComputeSettings& compute_settings, int num_threads) { + if (model_ == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInternal, + "TF Lite FlatBufferModel is null. Please make sure to call one of the " + "BuildModelFrom methods before calling InitInterpreter."); + } +#if TFLITE_USE_C_API + std::function<absl::Status(TfLiteDelegate*, + std::unique_ptr<Interpreter, InterpreterDeleter>*)> + initializer = [this, num_threads]( + TfLiteDelegate* optional_delegate, + std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out) + -> absl::Status { + std::unique_ptr<TfLiteInterpreterOptions, + void (*)(TfLiteInterpreterOptions*)> + options{TfLiteInterpreterOptionsCreate(), + TfLiteInterpreterOptionsDelete}; + TfLiteInterpreterOptionsSetOpResolver(options.get(), FindBuiltinOp, + FindCustomOp, resolver_.get()); + TfLiteInterpreterOptionsSetNumThreads(options.get(), num_threads); + if (optional_delegate != nullptr) { + TfLiteInterpreterOptionsAddDelegate(options.get(), optional_delegate); + } + interpreter_out->reset( + TfLiteInterpreterCreateWithSelectedOps(model_.get(), options.get())); + if (*interpreter_out == nullptr) { + return CreateStatusWithPayload( + StatusCode::kAborted, + absl::StrCat("Could not build 
the TF Lite interpreter: " + "TfLiteInterpreterCreateWithSelectedOps failed: ", + error_reporter_.error_message)); + } + return absl::OkStatus(); + }; +#else + auto initializer = + [this, num_threads]( + std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out) + -> absl::Status { + if (tflite::InterpreterBuilder(*model_, *resolver_)( + interpreter_out, num_threads) != kTfLiteOk) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrCat("Could not build the TF Lite interpreter: ", + error_reporter_.error_message)); + } + if (*interpreter_out == nullptr) { + return CreateStatusWithPayload(StatusCode::kInternal, + "TF Lite interpreter is null."); + } + return absl::OkStatus(); + }; +#endif + + absl::Status status = + interpreter_.InitializeWithFallback(initializer, compute_settings); + + if (!status.ok() && + !status.GetPayload(tflite::support::kTfLiteSupportPayload).has_value()) { + status = CreateStatusWithPayload(status.code(), status.message()); + } + return status; +} + +} // namespace core +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/core/tflite_engine.h b/tensorflow_lite_support/cc/task/core/tflite_engine.h new file mode 100644 index 00000000..30f239da --- /dev/null +++ b/tensorflow_lite_support/cc/task/core/tflite_engine.h @@ -0,0 +1,245 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ + +#include <sys/mman.h> + +#include <memory> + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow_lite_support/cc/port/tflite_wrapper.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" + +// If compiled with -DTFLITE_USE_C_API, this file will use the TF Lite C API +// rather than the TF Lite C++ API. +// TODO(b/168025296): eliminate the '#if TFLITE_USE_C_API' directives here and +// elsewhere and instead use the C API unconditionally, once we have a suitable +// replacement for the features of tflite::support::TfLiteInterpreterWrapper. +#if TFLITE_USE_C_API +#include "tensorflow/lite/c/c_api.h" +#include "tensorflow/lite/core/api/verifier.h" +#include "tensorflow/lite/tools/verifier.h" +#else +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/model.h" +#endif + +namespace tflite { +namespace task { +namespace core { + +// TfLiteEngine encapsulates logic for TFLite model initialization, inference +// and error reporting. +class TfLiteEngine { + public: + // Types. 
+ using InterpreterWrapper = tflite::support::TfLiteInterpreterWrapper; +#if TFLITE_USE_C_API + using Model = struct TfLiteModel; + using Interpreter = struct TfLiteInterpreter; + using ModelDeleter = void (*)(Model*); + using InterpreterDeleter = InterpreterWrapper::InterpreterDeleter; +#else + using Model = tflite::FlatBufferModel; + using Interpreter = tflite::Interpreter; + using ModelDeleter = std::default_delete<Model>; + using InterpreterDeleter = std::default_delete<Interpreter>; +#endif + + // Constructors. + explicit TfLiteEngine( + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + // Model is neither copyable nor movable. + TfLiteEngine(const TfLiteEngine&) = delete; + TfLiteEngine& operator=(const TfLiteEngine&) = delete; + + // Accessors. + static int32_t InputCount(const Interpreter* interpreter) { +#if TFLITE_USE_C_API + return TfLiteInterpreterGetInputTensorCount(interpreter); +#else + return interpreter->inputs().size(); +#endif + } + static int32_t OutputCount(const Interpreter* interpreter) { +#if TFLITE_USE_C_API + return TfLiteInterpreterGetOutputTensorCount(interpreter); +#else + return interpreter->outputs().size(); +#endif + } + static TfLiteTensor* GetInput(Interpreter* interpreter, int index) { +#if TFLITE_USE_C_API + return TfLiteInterpreterGetInputTensor(interpreter, index); +#else + return interpreter->tensor(interpreter->inputs()[index]); +#endif + } + // Same as above, but const. 
+ static const TfLiteTensor* GetInput(const Interpreter* interpreter, + int index) { +#if TFLITE_USE_C_API + return TfLiteInterpreterGetInputTensor(interpreter, index); +#else + return interpreter->tensor(interpreter->inputs()[index]); +#endif + } + static TfLiteTensor* GetOutput(Interpreter* interpreter, int index) { +#if TFLITE_USE_C_API + // We need a const_cast here, because the TF Lite C API only has a non-const + // version of GetOutputTensor (in part because C doesn't support overloading + // on const). + return const_cast<TfLiteTensor*>( + TfLiteInterpreterGetOutputTensor(interpreter, index)); +#else + return interpreter->tensor(interpreter->outputs()[index]); +#endif + } + // Same as above, but const. + static const TfLiteTensor* GetOutput(const Interpreter* interpreter, + int index) { +#if TFLITE_USE_C_API + return TfLiteInterpreterGetOutputTensor(interpreter, index); +#else + return interpreter->tensor(interpreter->outputs()[index]); +#endif + } + + std::vector<TfLiteTensor*> GetInputs(); + std::vector<const TfLiteTensor*> GetOutputs(); + + const Model* model() const { return model_.get(); } + Interpreter* interpreter() { return interpreter_.get(); } + const Interpreter* interpreter() const { return interpreter_.get(); } + InterpreterWrapper* interpreter_wrapper() { return &interpreter_; } + const tflite::metadata::ModelMetadataExtractor* metadata_extractor() const { + return model_metadata_extractor_.get(); + } + + // Builds the TF Lite FlatBufferModel (model_) from the raw FlatBuffer data + // whose ownership remains with the caller, and which must outlive the current + // object. This performs extra verification on the input data using + // tflite::Verify. + absl::Status BuildModelFromFlatBuffer(const char* buffer_data, + size_t buffer_size); + + // Builds the TF Lite model from a given file. + absl::Status BuildModelFromFile(const std::string& file_name); + + // Builds the TF Lite model from a given file descriptor using mmap(2). 
+ absl::Status BuildModelFromFileDescriptor(int file_descriptor); + + // Builds the TFLite model from the provided ExternalFile proto, which must + // outlive the current object. + absl::Status BuildModelFromExternalFileProto( + const ExternalFile* external_file); + + // Initializes interpreter with encapsulated model. + // Note: setting num_threads to -1 has for effect to let TFLite runtime set + // the value. + absl::Status InitInterpreter(int num_threads = 1); + + // Same as above, but allows specifying `compute_settings` for acceleration. + absl::Status InitInterpreter( + const tflite::proto::ComputeSettings& compute_settings, + int num_threads = 1); + + // Cancels the on-going `Invoke()` call if any and if possible. This method + // can be called from a different thread than the one where `Invoke()` is + // running. + void Cancel() { +#if TFLITE_USE_C_API + // NOP. +#else + interpreter_.Cancel(); +#endif + } + + protected: + // TF Lite's DefaultErrorReporter() outputs to stderr. This one captures the + // error into a string so that it can be used to complement tensorflow::Status + // error messages. + struct ErrorReporter : public tflite::ErrorReporter { + // Last error message captured by this error reporter. + char error_message[256]; + int Report(const char* format, va_list args) override; + }; + // Custom error reporter capturing low-level TF Lite error messages. + ErrorReporter error_reporter_; + + private: + // Direct wrapper around tflite::TfLiteVerifier which checks the integrity of + // the FlatBuffer data provided as input. + class Verifier : public tflite::TfLiteVerifier { + public: + explicit Verifier(const tflite::OpResolver* op_resolver) + : op_resolver_(op_resolver) {} + bool Verify(const char* data, int length, + tflite::ErrorReporter* reporter) override; + // The OpResolver to be used to build the TF Lite interpreter. 
+ const tflite::OpResolver* op_resolver_; + }; + + // Verifies that the supplied buffer refers to a valid flatbuffer model, + // and that it uses only operators that are supported by the OpResolver + // that was passed to the TfLiteEngine constructor, and then builds + // the model from the buffer and stores it in 'model_'. + void VerifyAndBuildModelFromBuffer(const char* buffer_data, + size_t buffer_size); + + // Gets the buffer from the file handler; verifies and builds the model + // from the buffer; if successful, sets 'model_metadata_extractor_' to be + // a TF Lite Metadata extractor for the model; and calculates an appropriate + // return Status, + absl::Status InitializeFromModelFileHandler(); + + // TF Lite model and interpreter for actual inference. + std::unique_ptr<Model, ModelDeleter> model_; + + // Interpreter wrapper built from the model. + InterpreterWrapper interpreter_; + + // TFLite Metadata extractor built from the model. + std::unique_ptr<tflite::metadata::ModelMetadataExtractor> + model_metadata_extractor_; + + // Mechanism used by TF Lite to map Ops referenced in the FlatBuffer model to + // actual implementation. Defaults to TF Lite BuiltinOpResolver. + std::unique_ptr<tflite::OpResolver> resolver_; + + // Extra verifier for FlatBuffer input data. + Verifier verifier_; + + // ExternalFile and corresponding ExternalFileHandler for models loaded from + // disk or file descriptor. 
+ ExternalFile external_file_; + std::unique_ptr<ExternalFileHandler> model_file_handler_; +}; + +} // namespace core +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_ diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD b/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD new file mode 100644 index 00000000..33b6f6a6 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/BUILD @@ -0,0 +1,118 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "bert_nl_classifier_c_api.h", + "nl_classifier_c_api.h", + "nl_classifier_c_api_common.h", +]) + +cc_library( + name = "nl_classifier", + srcs = [ + "nl_classifier.cc", + ], + hdrs = [ + "nl_classifier.h", + ], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:category", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/text/tokenizers:regex_tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", + "//tensorflow_lite_support/cc/utils:common_utils", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:string", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_library( + name = "nl_classifier_c_api", + srcs = [ + "nl_classifier_c_api.cc", + ], + hdrs = [ + "nl_classifier_c_api.h", + 
"nl_classifier_c_api_common.h", + ], + visibility = ["//tensorflow_lite_support:__subpackages__"], + deps = [ + ":nl_classifier", + ":nl_classifier_c_api_common", + "//tensorflow_lite_support/cc/task/core:category", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "bert_nl_classifier", + srcs = [ + "bert_nl_classifier.cc", + ], + hdrs = [ + "bert_nl_classifier.h", + ], + deps = [ + ":nl_classifier", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:category", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite:string", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + +cc_library( + name = "bert_nl_classifier_c_api", + srcs = [ + "bert_nl_classifier_c_api.cc", + ], + hdrs = [ + "bert_nl_classifier_c_api.h", + "nl_classifier_c_api_common.h", + ], + visibility = ["//tensorflow_lite_support:__subpackages__"], + deps = [ + ":bert_nl_classifier", + ":nl_classifier_c_api_common", + "//tensorflow_lite_support/cc/task/core:category", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "nl_classifier_c_api_common", + srcs = [ + "nl_classifier_c_api_common.cc", + ], + hdrs = [ + "nl_classifier_c_api_common.h", + ], + visibility = ["//tensorflow_lite_support:__subpackages__"], +) diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc 
b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc new file mode 100644 index 00000000..d689c9e8 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc @@ -0,0 +1,198 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" + +#include <stddef.h> + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_format.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/string_type.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" + +namespace tflite { +namespace task { +namespace text { +namespace 
nlclassifier { + +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit; +using ::tflite::support::text::tokenizer::TokenizerResult; +using ::tflite::task::core::FindTensorByName; +using ::tflite::task::core::PopulateTensor; + +namespace { +constexpr char kIdsTensorName[] = "ids"; +constexpr char kMaskTensorName[] = "mask"; +constexpr char kSegmentIdsTensorName[] = "segment_ids"; +constexpr char kScoreTensorName[] = "probability"; +constexpr char kClassificationToken[] = "[CLS]"; +constexpr char kSeparator[] = "[SEP]"; +constexpr int kTokenizerProcessUnitIndex = 0; +} // namespace + +absl::Status BertNLClassifier::Preprocess( + const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) { + auto* input_tensor_metadatas = + GetMetadataExtractor()->GetInputTensorMetadata(); + auto* ids_tensor = + FindTensorByName(input_tensors, input_tensor_metadatas, kIdsTensorName); + auto* mask_tensor = + FindTensorByName(input_tensors, input_tensor_metadatas, kMaskTensorName); + auto* segment_ids_tensor = FindTensorByName( + input_tensors, input_tensor_metadatas, kSegmentIdsTensorName); + + std::string processed_input = input; + absl::AsciiStrToLower(&processed_input); + + TokenizerResult input_tokenize_results; + input_tokenize_results = tokenizer_->Tokenize(processed_input); + + // 2 accounts for [CLS], [SEP] + absl::Span<const std::string> query_tokens = + absl::MakeSpan(input_tokenize_results.subwords.data(), + input_tokenize_results.subwords.data() + + std::min(static_cast<size_t>(kMaxSeqLen - 2), + input_tokenize_results.subwords.size())); + + std::vector<std::string> tokens; + tokens.reserve(2 + query_tokens.size()); + // Start of generating the features. + tokens.push_back(kClassificationToken); + // For query input. 
+ for (const auto& query_token : query_tokens) { + tokens.push_back(query_token); + } + // For Separation. + tokens.push_back(kSeparator); + + std::vector<int> input_ids(kMaxSeqLen, 0); + std::vector<int> input_mask(kMaxSeqLen, 0); + // Convert tokens back into ids and set mask + for (int i = 0; i < tokens.size(); ++i) { + tokenizer_->LookupId(tokens[i], &input_ids[i]); + input_mask[i] = 1; + } + // |<-----------kMaxSeqLen---------->| + // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0 + // input_masks 1 1 1... 1 1 0 0... 0 + // segment_ids 0 0 0... 0 0 0 0... 0 + + PopulateTensor(input_ids, ids_tensor); + PopulateTensor(input_mask, mask_tensor); + PopulateTensor(std::vector<int>(kMaxSeqLen, 0), segment_ids_tensor); + + return absl::OkStatus(); +} + +StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& /*input*/) { + if (output_tensors.size() != 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("BertNLClassifier models are expected to have only 1 " + "output, found %d", + output_tensors.size()), + TfLiteSupportStatus::kInvalidNumOutputTensorsError); + } + const TfLiteTensor* scores = FindTensorByName( + output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(), + kScoreTensorName); + + // optional labels extracted from metadata + return BuildResults(scores, /*labels=*/nullptr); +} + +StatusOr<std::unique_ptr<BertNLClassifier>> +BertNLClassifier::CreateFromFile( + const std::string& path_to_model_with_metadata, + std::unique_ptr<tflite::OpResolver> resolver) { + std::unique_ptr<BertNLClassifier> bert_nl_classifier; + ASSIGN_OR_RETURN(bert_nl_classifier, + core::TaskAPIFactory::CreateFromFile<BertNLClassifier>( + path_to_model_with_metadata, std::move(resolver))); + RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata()); + return std::move(bert_nl_classifier); +} + +StatusOr<std::unique_ptr<BertNLClassifier>> 
+BertNLClassifier::CreateFromBuffer( + const char* model_with_metadata_buffer_data, + size_t model_with_metadata_buffer_size, + std::unique_ptr<tflite::OpResolver> resolver) { + std::unique_ptr<BertNLClassifier> bert_nl_classifier; + ASSIGN_OR_RETURN(bert_nl_classifier, + core::TaskAPIFactory::CreateFromBuffer<BertNLClassifier>( + model_with_metadata_buffer_data, + model_with_metadata_buffer_size, std::move(resolver))); + RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata()); + return std::move(bert_nl_classifier); +} + +StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromFd( + int fd, std::unique_ptr<tflite::OpResolver> resolver) { + std::unique_ptr<BertNLClassifier> bert_nl_classifier; + ASSIGN_OR_RETURN( + bert_nl_classifier, + core::TaskAPIFactory::CreateFromFileDescriptor<BertNLClassifier>( + fd, std::move(resolver))); + RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata()); + return std::move(bert_nl_classifier); +} + +absl::Status BertNLClassifier::InitializeFromMetadata() { + // Set up mandatory tokenizer. + const ProcessUnit* tokenizer_process_unit = + GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex); + if (tokenizer_process_unit == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "No input process unit found from metadata.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + ASSIGN_OR_RETURN(tokenizer_, + CreateTokenizerFromProcessUnit(tokenizer_process_unit, + GetMetadataExtractor())); + + // Set up optional label vector. 
+ TrySetLabelFromMetadata( + GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex)) + .IgnoreError(); + return absl::OkStatus(); +} + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h new file mode 100644 index 00000000..0c709ee0 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_ + +#include <stddef.h> + +#include <memory> +#include <string> +#include <vector> + +#include "absl/status/status.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/string_type.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { + +// Classifier API for NLClassification tasks with Bert models, categorizes +// string into different classes. +// +// The API expects a Bert based TFLite model with metadata populated. +// The metadata should contain the following information: +// - input_process_units for Wordpiece/Sentencepiece Tokenizer +// - 3 input tensors with names "ids", "mask" and "segment_ids" +// - 1 output tensor of type float32[1, 2], with a optionally attached label +// file. If a label file is attached, the file should be a plain text file +// with one label per line, the number of labels should match the number of +// categories the model outputs. + +class BertNLClassifier : public NLClassifier { + public: + using NLClassifier::NLClassifier; + // Max number of tokens to pass to the model. + static constexpr int kMaxSeqLen = 128; + + // Factory function to create a BertNLClassifier from TFLite model with + // metadata. 
+ static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> + CreateFromFile( + const std::string& path_to_model_with_metadata, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + + // Factory function to create a BertNLClassifier from in memory buffer of a + // TFLite model with metadata. + static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> + CreateFromBuffer( + const char* model_with_metadata_buffer_data, + size_t model_with_metadata_buffer_size, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + + // Factory function to create a BertNLClassifier from the file descriptor of a + // TFLite model with metadata. + static tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> + CreateFromFd( + int fd, std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + + protected: + // Run tokenization on input text and construct three input tensors ids, mask + // and segment_ids for the model input. + absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const std::string& input) override; + + // Extract model output and create results with label file attached in + // metadata. If no label file is attached, use output score index as labels. + tflite::support::StatusOr<std::vector<core::Category>> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& input) override; + + private: + // Initialize the API with the tokenizer and label files set in the metadata. 
+ absl::Status InitializeFromMetadata(); + + std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_; +}; + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_H_ diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc new file mode 100644 index 00000000..0decc497 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.cc @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h" + +#include <memory> + +#include "absl/strings/string_view.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" + +using CategoryCPP = ::tflite::task::core::Category; +using BertNLClassifierCPP = + ::tflite::task::text::nlclassifier::BertNLClassifier; + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct BertNLClassifier { + std::unique_ptr<BertNLClassifierCPP> impl; +}; + +BertNLClassifier* BertNLClassifierFromFile(const char* model_path) { + auto classifier_status = + BertNLClassifierCPP::CreateFromFile(std::string(model_path)); + if (classifier_status.ok()) { + return new BertNLClassifier{.impl = std::unique_ptr<BertNLClassifierCPP>( + dynamic_cast<BertNLClassifierCPP*>( + classifier_status.value().release()))}; + } else { + return nullptr; + } +} + +Categories* BertNLClassifierClassify(const BertNLClassifier* classifier, + const char* text) { + std::vector<CategoryCPP> results = + classifier->impl->Classify(absl::string_view(text).data()); + size_t size = results.size(); + auto* categories = new Category[size]; + + for (size_t i = 0; i < size; ++i) { + categories[i].text = strdup(results[i].class_name.c_str()); + categories[i].score = results[i].score; + } + + auto* c_categories = new Categories; + c_categories->size = size; + c_categories->categories = categories; + return c_categories; +} + +void BertNLClassifierDelete(BertNLClassifier* classifier) { delete classifier; } + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h new file mode 100644 index 00000000..1d0b8b67 --- /dev/null +++ 
b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_C_API_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_C_API_H_
+
+#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h"
+// --------------------------------------------------------------------------
+/// C API for BertNLClassifier.
+///
+/// The API leans towards simplicity and uniformity instead of convenience, as
+/// most usage will be by language-specific wrappers. It provides largely the
+/// same set of functionality as that of the C++ TensorFlow Lite
+/// `BertNLClassifier` API, but is useful for shared libraries where having
+/// a stable ABI boundary is important.
+///
+/// Usage:
+/// <pre><code>
+/// // Create the model and interpreter options.
+/// BertNLClassifier* classifier =
+///     BertNLClassifierFromFile("/path/to/model.tflite");
+///
+/// // Run classification on the input text.
+/// Categories* categories = BertNLClassifierClassify(classifier, text);
+///
+/// // Dispose of the API object.
+/// BertNLClassifierDelete(classifier);
+/// </code></pre>
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef struct BertNLClassifier BertNLClassifier;
+
+// Creates a BertNLClassifier from the model file at `model_path`; returns
+// nullptr if the file doesn't exist or is not a well-formatted TFLite model
+// with metadata. The caller owns the result and must release it with
+// BertNLClassifierDelete.
+extern BertNLClassifier* BertNLClassifierFromFile(const char* model_path);
+
+// Invokes the encapsulated TFLite model and classifies the input text.
+extern struct Categories* BertNLClassifierClassify(
+    const BertNLClassifier* classifier, const char* text);
+
+// Releases a classifier created by BertNLClassifierFromFile.
+extern void BertNLClassifierDelete(BertNLClassifier* classifier);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_BERT_NL_CLASSIFIER_C_API_H_ diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc new file mode 100644 index 00000000..1643e3e0 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc @@ -0,0 +1,467 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" + +#include <cstddef> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/cc/utils/common_utils.h" + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { + +using ::absl::StatusCode; +using ::flatbuffers::Offset; +using ::flatbuffers::Vector; +using ::tflite::TensorMetadata; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::support::text::tokenizer::RegexTokenizer; +using ::tflite::support::text::tokenizer::Tokenizer; +using ::tflite::support::text::tokenizer::TokenizerResult; +using ::tflite::support::utils::LoadVocabFromBuffer; +using ::tflite::task::core::Category; +using ::tflite::task::core::Dequantize; +using ::tflite::task::core::GetStringAtIndex; +using ::tflite::task::core::PopulateTensor; + +namespace { +constexpr int kRegexTokenizerInputTensorIndex = 0; +constexpr int 
kRegexTokenizerProcessUnitIndex = 0; + +StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile( + const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>* + associated_files, + const tflite::metadata::ModelMetadataExtractor* metadata_extractor) { + if (associated_files == nullptr || associated_files->size() < 1 || + associated_files->Get(0)->name() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid vocab_file from input process unit.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + ASSIGN_OR_RETURN(absl::string_view vocab_buffer, + metadata_extractor->GetAssociatedFile( + associated_files->Get(0)->name()->str())); + return vocab_buffer; +} + +StatusOr<std::unique_ptr<Tokenizer>> CreateRegexTokenizerFromProcessUnit( + const tflite::ProcessUnit* tokenizer_process_unit, + const tflite::metadata::ModelMetadataExtractor* metadata_extractor) { + if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "No metadata or input process unit found.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + + if (tokenizer_process_unit->options_type() != + ProcessUnitOptions_RegexTokenizerOptions) { + return CreateStatusWithPayload( + absl::StatusCode::kNotFound, + absl::StrCat( + "Incorrect options_type:", tokenizer_process_unit->options_type(), + " need RegexTokenizerOptions."), + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + + const tflite::RegexTokenizerOptions* options = + tokenizer_process_unit->options_as<RegexTokenizerOptions>(); + ASSIGN_OR_RETURN(absl::string_view vocab_buffer, + CheckAndLoadFirstAssociatedFile(options->vocab_file(), + metadata_extractor)); + if (options->delim_regex_pattern() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid delim_regex_pattern from input process unit.", + 
TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + + std::unique_ptr<RegexTokenizer> regex_tokenizer = + absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(), + vocab_buffer.data(), + vocab_buffer.size()); + + int unknown_token_id = 0; + if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "RegexTokenizer doesn't have <UNKNOWN> token.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + + int pad_token_id = 0; + if (!regex_tokenizer->GetPadToken(&pad_token_id)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "RegexTokenizer doesn't have <PAD> token.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + return regex_tokenizer; +} + +} // namespace + +const NLClassifierOptions& NLClassifier::GetOptions() const { return options_; } + +absl::Status NLClassifier::TrySetLabelFromMetadata( + const TensorMetadata* metadata) { + if (metadata == nullptr) { + return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, + "Metadata not found for output tensor", + TfLiteSupportStatus::kMetadataNotFoundError); + } + const auto* associated_files = metadata->associated_files(); + if (associated_files == nullptr || associated_files->size() == 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "No label file found for tensor metadata.", + TfLiteSupportStatus::kMetadataMissingLabelsError); + } + const tflite::AssociatedFile* associated_file = + associated_files->Get(kOutputTensorLabelFileIndex); + if (associated_file->type() != AssociatedFileType_TENSOR_AXIS_LABELS) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Incorrect label type found for tensor metadata.", + TfLiteSupportStatus::kMetadataMissingLabelsError); + } + tflite::support::StatusOr<absl::string_view> label_buffer = + GetMetadataExtractor()->GetAssociatedFile( + 
associated_files->Get(kOutputTensorIndex)->name()->str()); + if (label_buffer.ok()) { + labels_vector_ = + absl::make_unique<std::vector<std::string>>(LoadVocabFromBuffer( + label_buffer.value().data(), label_buffer.value().size())); + return absl::OkStatus(); + } else { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Failed to extract label file from metadata.", + TfLiteSupportStatus::kMetadataMissingLabelsError); + } +} + +std::vector<Category> NLClassifier::Classify(const std::string& text) { + // The NLClassifier implementation for Preprocess() and Postprocess() never + // returns errors: just call value(). + return Infer(text).value(); +} + +absl::Status NLClassifier::Preprocess( + const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) { + TfLiteTensor* input_tensor = FindTensorWithNameOrIndex( + input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(), + options_.input_tensor_name, options_.input_tensor_index); + if (input_tensor == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "No input tensor found from NLClassifierOptions.", + TfLiteSupportStatus::kInputTensorNotFoundError); + } + + if (HasRegexTokenizerMetadata()) { + // |<-------sentence_length-------->| + // input_tensor <START>, t1, t2... <PAD>, <PAD>... + // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's not + // found in tokenizer vocab. + TokenizerResult result = tokenizer_->Tokenize(input); + + size_t max_sentence_length = input_tensor->dims->size == 2 + ? 
input_tensor->dims->data[1] + : input_tensor->dims->data[0]; + + int unknown_token_id = 0; + tokenizer_->GetUnknownToken(&unknown_token_id); + + int pad_token_id = 0; + tokenizer_->GetPadToken(&pad_token_id); + + std::vector<int> input_tokens(max_sentence_length, pad_token_id); + int start_token_id = 0; + size_t input_token_index = 0; + if (tokenizer_->GetStartToken(&start_token_id)) { + input_tokens[0] = start_token_id; + input_token_index = 1; + } + + for (size_t i = 0; (i < result.subwords.size()) && + (input_token_index < max_sentence_length); + ++i, ++input_token_index) { + const std::string& token = result.subwords[i]; + int token_id = 0; + if (tokenizer_->LookupId(token, &token_id)) { + input_tokens[input_token_index] = token_id; + } else { + input_tokens[input_token_index] = unknown_token_id; + } + } + + PopulateTensor(input_tokens, input_tensor); + } else { + PopulateTensor(input, input_tensor); + } + return absl::OkStatus(); +} + +StatusOr<std::vector<Category>> NLClassifier::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& /*input*/) { + return BuildResults( + FindTensorWithNameOrIndex( + output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(), + options_.output_score_tensor_name, + options_.output_score_tensor_index), + FindTensorWithNameOrIndex( + output_tensors, GetMetadataExtractor()->GetInputTensorMetadata(), + options_.output_label_tensor_name, + options_.output_label_tensor_index)); +} + +std::vector<Category> NLClassifier::BuildResults(const TfLiteTensor* scores, + const TfLiteTensor* labels) { + bool use_index_as_labels = (labels_vector_ == nullptr) && (labels == nullptr); + // Some models output scores with transposed shape [1, categories] + int categories = + scores->dims->size == 2 ? 
scores->dims->data[1] : scores->dims->data[0]; + + std::vector<Category> predictions; + predictions.reserve(categories); + + bool should_dequantize = scores->type == kTfLiteUInt8 || + scores->type == kTfLiteInt8 || + scores->type == kTfLiteInt16; + for (int index = 0; index < categories; index++) { + std::string label; + if (use_index_as_labels) { + label = std::to_string(index); + } else if (labels_vector_ == nullptr) { + if (labels->type == kTfLiteString) { + label = GetStringAtIndex(labels, index); + } else if (labels->type == kTfLiteInt32) { + label = std::to_string(GetTensorData<int>(labels)[index]); + } + } else { + label = (*labels_vector_)[index]; + } + if (should_dequantize) { + predictions.push_back(Category(label, Dequantize(*scores, index))); + } else if (scores->type == kTfLiteBool) { + predictions.push_back( + Category(label, GetTensorData<bool>(scores)[index] ? 1.0 : 0.0)); + } else { + predictions.push_back( + Category(label, scores->type == kTfLiteFloat32 + ? GetTensorData<float>(scores)[index] + : GetTensorData<double>(scores)[index])); + } + } + + return predictions; +} +absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) { + options_ = options; + // input tensor should be type STRING + auto input_tensor = FindTensorWithNameOrIndex( + GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(), + options.input_tensor_name, options.input_tensor_index); + if (input_tensor == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("No input tensor found with name ", + options.input_tensor_name, " or at index ", + options.input_tensor_index), + TfLiteSupportStatus::kInputTensorNotFoundError); + } + if (HasRegexTokenizerMetadata()) { + if (input_tensor->type != kTfLiteInt32) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("Type mismatch for input tensor ", input_tensor->name, + ". 
Requested INT32, got ", + TfLiteTypeGetName(input_tensor->type), "."), + TfLiteSupportStatus::kInvalidInputTensorTypeError); + } + RETURN_IF_ERROR(SetupRegexTokenizer()); + } else { + if (input_tensor->type != kTfLiteString) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("Type mismatch for input tensor ", input_tensor->name, + ". Requested STRING, got ", + TfLiteTypeGetName(input_tensor->type), "."), + TfLiteSupportStatus::kInvalidInputTensorTypeError); + } + } + + // output score tensor should be type + // UINT8/INT8/INT16(quantized) or FLOAT32/FLOAT64(dequantized) or BOOL + std::vector<const TfLiteTensor*> output_tensors = GetOutputTensors(); + const Vector<Offset<TensorMetadata>>* output_tensor_metadatas = + GetMetadataExtractor()->GetOutputTensorMetadata(); + + const auto scores = FindTensorWithNameOrIndex( + output_tensors, output_tensor_metadatas, options.output_score_tensor_name, + options.output_score_tensor_index); + if (scores == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("No output score tensor found with name ", + options.output_score_tensor_name, " or at index ", + options.output_score_tensor_index), + TfLiteSupportStatus::kOutputTensorNotFoundError); + } + static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteInt8, + kTfLiteInt16, kTfLiteFloat32, + kTfLiteFloat64, kTfLiteBool}; + if (!absl::c_linear_search(valid_types, scores->type)) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("Type mismatch for score tensor ", scores->name, + ". Requested one of these types: " + "INT8/UINT8/INT16/FLOAT32/FLOAT64/BOOL, got ", + TfLiteTypeGetName(scores->type), "."), + TfLiteSupportStatus::kInvalidOutputTensorTypeError); + } + + // Extract associated label file from output score tensor if one exists, a + // well-formatted metadata should have same number of tensors with the model. 
+ if (output_tensor_metadatas && + output_tensor_metadatas->size() == output_tensors.size()) { + for (int i = 0; i < output_tensor_metadatas->size(); ++i) { + const tflite::TensorMetadata* metadata = output_tensor_metadatas->Get(i); + if ((metadata->name() && metadata->name()->string_view() == + options.output_score_tensor_name) || + i == options.output_score_tensor_index) { + if (TrySetLabelFromMetadata(metadata).ok()) { + return absl::OkStatus(); + } + } + } + } + + // If labels_vector_ is not set up from metadata, try register output label + // tensor from options. + if (labels_vector_ == nullptr) { + // output label tensor should be type STRING or INT32 if the one exists + auto labels = FindTensorWithNameOrIndex( + output_tensors, output_tensor_metadatas, + options.output_label_tensor_name, options.output_label_tensor_index); + if (labels != nullptr && labels->type != kTfLiteString && + labels->type != kTfLiteInt32) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("Type mismatch for label tensor ", scores->name, + ". 
Requested STRING or INT32, got ", + TfLiteTypeGetName(scores->type), "."), + TfLiteSupportStatus::kInvalidOutputTensorTypeError); + } + } + return absl::OkStatus(); +} + +StatusOr<std::unique_ptr<NLClassifier>> +NLClassifier::CreateFromBufferAndOptions( + const char* model_buffer_data, size_t model_buffer_size, + const NLClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + std::unique_ptr<NLClassifier> nl_classifier; + ASSIGN_OR_RETURN( + nl_classifier, + core::TaskAPIFactory::CreateFromBuffer<NLClassifier>( + model_buffer_data, model_buffer_size, std::move(resolver))); + RETURN_IF_ERROR(nl_classifier->Initialize(options)); + return std::move(nl_classifier); +} + +StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions( + const std::string& path_to_model, const NLClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + std::unique_ptr<NLClassifier> nl_classifier; + ASSIGN_OR_RETURN(nl_classifier, + core::TaskAPIFactory::CreateFromFile<NLClassifier>( + path_to_model, std::move(resolver))); + RETURN_IF_ERROR(nl_classifier->Initialize(options)); + return std::move(nl_classifier); +} + +StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions( + int fd, const NLClassifierOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + std::unique_ptr<NLClassifier> nl_classifier; + ASSIGN_OR_RETURN(nl_classifier, + core::TaskAPIFactory::CreateFromFileDescriptor<NLClassifier>( + fd, std::move(resolver))); + RETURN_IF_ERROR(nl_classifier->Initialize(options)); + return std::move(nl_classifier); +} + +bool NLClassifier::HasRegexTokenizerMetadata() { + const TensorMetadata* input_tensor_metadata = + GetMetadataExtractor()->GetInputTensorMetadata( + kRegexTokenizerInputTensorIndex); + if (input_tensor_metadata == nullptr) { + return false; + } + tflite::support::StatusOr<const tflite::ProcessUnit*> status = + GetMetadataExtractor()->FindFirstProcessUnit( + 
*input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions); + return status.ok() ? status.value() != nullptr : false; +} + +absl::Status NLClassifier::SetupRegexTokenizer() { + ASSIGN_OR_RETURN( + std::unique_ptr<Tokenizer> base_tokenizer, + CreateRegexTokenizerFromProcessUnit( + GetMetadataExtractor() + ->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex) + ->process_units() + ->Get(kRegexTokenizerProcessUnitIndex), + GetMetadataExtractor())); + + tokenizer_ = std::unique_ptr<RegexTokenizer>( + dynamic_cast<RegexTokenizer*>(base_tokenizer.release())); + + return absl::OkStatus(); +} + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h new file mode 100644 index 00000000..2a9573a1 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h @@ -0,0 +1,181 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_ + +#include <stddef.h> +#include <string.h> + +#include <memory> +#include <string> +#include <vector> + +#include "absl/status/status.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/string_type.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h" + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { + +// Options to identify input and output tensors of the model +struct NLClassifierOptions { + int input_tensor_index = 0; + int output_score_tensor_index = 0; + // By default there is no output label tensor. The label file can be attached + // to the output score tensor metadata. + int output_label_tensor_index = -1; + std::string input_tensor_name = "INPUT"; + std::string output_score_tensor_name = "OUTPUT_SCORE"; + std::string output_label_tensor_name = "OUTPUT_LABEL"; +}; + +// Classifier API for NLClassification tasks, categorizes string into different +// classes. +// +// The API expects a TFLite model with the following input/output tensor: +// Input tensor: +// (kTfLiteString) - input of the model, accepts a string. +// or +// (kTfLiteInt32) - input of the model, accepts a tokenized +// indices of a string input. A RegexTokenizer needs to be set up in the input +// tensor's metadata. 
+// Output score tensor: +// (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/ +// kTfLiteFloat64/kTfLiteBool) +// - output scores for each class, if type is one of the Int types, +// dequantize it to double, if type is kTfLiteBool, convert the values to +// 0.0 and 1.0 respectively +// - can have an optional associated file in metadata for labels, the file +// should be a plain text file with one label per line, the number of +// labels should match the number of categories the model outputs. +// Output label tensor: optional +// (kTfLiteString/kTfLiteInt32) +// - output classname for each class, should be of the same length with +// scores. If this tensor is not present, the API uses score indices as +// classnames. +// - will be ignored if output score tensor already has an associated label +// file. +// +// By default the API tries to find the input/output tensors with default +// configurations in NLClassifierOptions, with tensor name prioritized over +// tensor index. The option is configurable for different TFLite models. +class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, + const std::string&> { + public: + using BaseTaskApi::BaseTaskApi; + + // Creates a NLClassifier from TFLite model buffer. + static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> + CreateFromBufferAndOptions( + const char* model_buffer_data, size_t model_buffer_size, + const NLClassifierOptions& options = {}, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + + // Creates a NLClassifier from TFLite model file. + static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> + CreateFromFileAndOptions( + const std::string& path_to_model, const NLClassifierOptions& options = {}, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + + // Creates a NLClassifier from TFLite model file descriptor. 
+ static tflite::support::StatusOr<std::unique_ptr<NLClassifier>> + CreateFromFdAndOptions( + int fd, const NLClassifierOptions& options = {}, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + + // Performs classification on a string input, returns classified results. + std::vector<core::Category> Classify(const std::string& text); + + protected: + static constexpr int kOutputTensorIndex = 0; + static constexpr int kOutputTensorLabelFileIndex = 0; + + absl::Status Initialize(const NLClassifierOptions& options); + const NLClassifierOptions& GetOptions() const; + + // Try to extract attached label file from metadata and initialize + // labels_vector_, return error if metadata type is incorrect or no label file + // is attached in metadata. + absl::Status TrySetLabelFromMetadata(const TensorMetadata* metadata); + + // Pass through the input text into model's input tensor. + absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const std::string& input) override; + + // Extract model output and create results with output label tensor or label + // file attached in metadata. If no output label tensor or label file is + // found, use output score index as labels. + tflite::support::StatusOr<std::vector<core::Category>> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& input) override; + + std::vector<core::Category> BuildResults(const TfLiteTensor* scores, + const TfLiteTensor* labels); + + // Gets the tensor from a vector of tensors by checking tensor name first and + // tensor index second, return nullptr if no tensor is found. 
+ template <typename TensorType> + static TensorType* FindTensorWithNameOrIndex( + const std::vector<TensorType*>& tensors, + const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>* + metadata_array, + const std::string& name, int index) { + if (metadata_array != nullptr && metadata_array->size() == tensors.size()) { + for (int i = 0; i < metadata_array->size(); i++) { + if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) { + return tensors[i]; + } + } + } + + for (TensorType* tensor : tensors) { + if (tensor->name == name) { + return tensor; + } + } + return index >= 0 && index < tensors.size() ? tensors[index] : nullptr; + } + + private: + bool HasRegexTokenizerMetadata(); + absl::Status SetupRegexTokenizer(); + + NLClassifierOptions options_; + // labels vector initialized from output tensor's associated file, if one + // exists. + std::unique_ptr<std::vector<std::string>> labels_vector_; + std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_; +}; + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_ diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc new file mode 100644 index 00000000..3f7827d8 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.cc @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h" + +#include <memory> + +#include "absl/strings/string_view.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" + +using CategoryCPP = ::tflite::task::core::Category; +using NLClassifierCPP = ::tflite::task::text::nlclassifier::NLClassifier; +using NLClassifierOptionsCPP = + ::tflite::task::text::nlclassifier::NLClassifierOptions; + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct NLClassifier { + std::unique_ptr<NLClassifierCPP> impl; +}; + +NLClassifier* NLClassifierFromFileAndOptions( + const char* model_path, const NLClassifierOptions* options) { + auto classifier_status = NLClassifierCPP::CreateFromFileAndOptions( + std::string(model_path), + { + .input_tensor_index = options->input_tensor_index, + .output_score_tensor_index = options->output_score_tensor_index, + .output_label_tensor_index = options->output_label_tensor_index, + .input_tensor_name = !options->input_tensor_name + ? "" + : std::string(options->input_tensor_name), + .output_score_tensor_name = + !options->output_score_tensor_name + ? "" + : std::string(options->output_score_tensor_name), + .output_label_tensor_name = + !options->output_label_tensor_name + ? 
"" + : std::string(options->output_label_tensor_name), + }); + + if (classifier_status.ok()) { + return new NLClassifier{ + .impl = std::unique_ptr<NLClassifierCPP>(dynamic_cast<NLClassifierCPP*>( + classifier_status.value().release()))}; + } else { + return nullptr; + } +} + +Categories* NLClassifierClassify(const NLClassifier* classifier, + const char* text) { + std::vector<CategoryCPP> results = + classifier->impl->Classify(absl::string_view(text).data()); + size_t size = results.size(); + auto* categories = new Category[size]; + + for (size_t i = 0; i < size; ++i) { + categories[i].text = strdup(results[i].class_name.c_str()); + categories[i].score = results[i].score; + } + + auto* c_categories = new Categories; + c_categories->size = size; + c_categories->categories = categories; + return c_categories; +} + +void NLClassifierDelete(NLClassifier* classifier) { delete classifier; } + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h new file mode 100644 index 00000000..1af93f29 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_H_ + + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h" +// -------------------------------------------------------------------------- +/// C API for NLClassifier. +/// +/// The API leans towards simplicity and uniformity instead of convenience, as +/// most usage will be by language-specific wrappers. It provides largely the +/// same set of functionality as that of the C++ TensorFlow Lite `NLClassifier` +/// API, but is useful for shared libraries where having a stable ABI boundary +/// is important. +/// +/// Usage: +/// <pre><code> +/// // Create the model and interpreter options. +/// NLClassifier* classifier = NLClassifierFromFileAndOptions( +/// "/path/to/model.tflite"); +/// +/// // classification. +/// Categories* categories = Classify(classifier, context, question); +/// +/// // Dispose of the API object. +/// NLClassifierDelete(classifier); + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct NLClassifier NLClassifier; + +struct NLClassifierOptions { + int input_tensor_index; + int output_score_tensor_index; + int output_label_tensor_index; + const char* input_tensor_name; + const char* output_score_tensor_name; + const char* output_label_tensor_name; +}; + +// Creates NLClassifier from model path and options, returns nullptr if the file +// doesn't exist or is not a well formatted TFLite model path. +extern NLClassifier* NLClassifierFromFileAndOptions( + const char* model_path, + const struct NLClassifierOptions* options); + +// Invokes the encapsulated TFLite model and classifies the input text. 
+extern struct Categories* NLClassifierClassify(const NLClassifier* classifier, + const char* text); + +extern void NLClassifierDelete(NLClassifier* classifier); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_H_ diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc new file mode 100644 index 00000000..3beb658a --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.cc @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h" + + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +void NLClassifierCategoriesDelete(Categories* categories) { + delete[] categories->categories; + delete categories; +} + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h new file mode 100644 index 00000000..663c873c --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_ + +// Common structs shared between NLClassifier APIs +// +/// // Dispose of the Categories object. 
+/// NLClassifierCategoriesDelete(categories); + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct Category { + char* text; + double score; +}; + +struct Categories { + int size; + struct Category* categories; +}; + +extern void NLClassifierCategoriesDelete(struct Categories* categories); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_C_API_COMMON_H_ diff --git a/tensorflow_lite_support/cc/task/text/qa/BUILD b/tensorflow_lite_support/cc/task/text/qa/BUILD new file mode 100644 index 00000000..49ad5a1f --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/qa/BUILD @@ -0,0 +1,61 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "bert_qa_c_api.h", +]) + +cc_library( + name = "question_answerer", + hdrs = [ + "question_answerer.h", + ], + deps = [ + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + ], +) + +cc_library( + name = "bert_question_answerer", + srcs = [ + "bert_question_answerer.cc", + ], + hdrs = [ + "bert_question_answerer.h", + ], + deps = [ + ":question_answerer", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:sentencepiece_tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + 
"@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "bert_qa_c_api", + srcs = [ + "bert_qa_c_api.cc", + ], + hdrs = [ + "bert_qa_c_api.h", + ], + visibility = ["//tensorflow_lite_support:__subpackages__"], + deps = [ + ":bert_question_answerer", + ":question_answerer", + ], +) diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc b/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc new file mode 100644 index 00000000..5fafb59d --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.cc @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h" + +#include <memory> + +#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h" +#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h" + +using BertQuestionAnswererCPP = ::tflite::task::text::qa::BertQuestionAnswerer; +using QaAnswerCPP = ::tflite::task::text::qa::QaAnswer; + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +struct BertQuestionAnswerer { + std::unique_ptr<BertQuestionAnswererCPP> impl; +}; + +BertQuestionAnswerer* BertQuestionAnswererFromFile(const char* model_path) { + auto bert_qa_status = + BertQuestionAnswererCPP::CreateFromFile(std::string(model_path)); + if (bert_qa_status.ok()) { + return new BertQuestionAnswerer{ + .impl = std::unique_ptr<BertQuestionAnswererCPP>( + dynamic_cast<BertQuestionAnswererCPP*>( + bert_qa_status.value().release()))}; + } else { + return nullptr; + } +} + +QaAnswers* BertQuestionAnswererAnswer( + const BertQuestionAnswerer* question_answerer, const char* context, + const char* question) { + std::vector<QaAnswerCPP> answers = question_answerer->impl->Answer( + absl::string_view(context).data(), absl::string_view(question).data()); + size_t size = answers.size(); + auto* qa_answers = new QaAnswer[size]; + + for (size_t i = 0; i < size; ++i) { + qa_answers[i].start = answers[i].pos.start; + qa_answers[i].end = answers[i].pos.end; + qa_answers[i].logit = answers[i].pos.logit; + qa_answers[i].text = strdup(answers[i].text.c_str()); + } + + auto* c_answers = new QaAnswers; + c_answers->size = size; + c_answers->answers = qa_answers; + return c_answers; +} + +void BertQuestionAnswererDelete(BertQuestionAnswerer* bert_question_answerer) { + delete bert_question_answerer; +} + +void BertQuestionAnswererQaAnswersDelete(QaAnswers* qa_answers) { + delete[] qa_answers->answers; + delete qa_answers; +} + +#ifdef __cplusplus +} // 
extern "C" +#endif // __cplusplus diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h b/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h new file mode 100644 index 00000000..7fd36948 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h @@ -0,0 +1,78 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QA_C_API_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QA_C_API_H_ + +// -------------------------------------------------------------------------- +/// C API for BertQuestionAnswerer. +/// +/// The API leans towards simplicity and uniformity instead of convenience, as +/// most usage will be by language-specific wrappers. It provides largely the +/// same set of functionality as that of the C++ TensorFlow Lite +/// `BertQuestionAnswerer` API, but is useful for shared libraries where having +/// a stable ABI boundary is important. +/// +/// Usage: +/// <pre><code> +/// // Create the model and interpreter options. +/// BertQuestionAnswerer* qa_answerer = +/// BertQuestionAnswererFromFile("/path/to/model.tflite"); +/// +/// // answer a question. +/// QaAnswers* answers = Answer(qa_answerer, context, question); +/// +/// // Dispose of the API and QaAnswers objects. 
+/// BertQuestionAnswererDelete(qa_answerer); +/// BertQuestionAnswererQaAnswersDelete(answers); + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct BertQuestionAnswerer BertQuestionAnswerer; + +struct QaAnswer { + int start; + int end; + float logit; + char* text; +}; + +struct QaAnswers { + int size; + struct QaAnswer* answers; +}; + +// Creates BertQuestionAnswerer from model path, returns nullptr if the file +// doesn't exist or is not a well formatted TFLite model path. +extern BertQuestionAnswerer* BertQuestionAnswererFromFile( + const char* model_path); + +// Invokes the encapsulated TFLite model and answers a question based on +// context. +extern struct QaAnswers* BertQuestionAnswererAnswer( + const BertQuestionAnswerer* question_answerer, const char* context, + const char* question); + +extern void BertQuestionAnswererDelete( + BertQuestionAnswerer* bert_question_answerer); + +extern void BertQuestionAnswererQaAnswersDelete(struct QaAnswers* qa_answers); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QA_C_API_H_ diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc new file mode 100644 index 00000000..aa7ffef3 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.cc @@ -0,0 +1,393 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h" + +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace text { +namespace qa { + +constexpr char kIdsTensorName[] = "ids"; +constexpr char kMaskTensorName[] = "mask"; +constexpr char kSegmentIdsTensorName[] = "segment_ids"; +constexpr char kEndLogitsTensorName[] = "end_logits"; +constexpr char kStartLogitsTensorName[] = "start_logits"; + +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::support::text::tokenizer::BertTokenizer; +using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit; +using ::tflite::support::text::tokenizer::SentencePieceTokenizer; +using ::tflite::support::text::tokenizer::TokenizerResult; +using ::tflite::task::core::FindTensorByName; +using ::tflite::task::core::PopulateTensor; +using ::tflite::task::core::PopulateVector; +using ::tflite::task::core::ReverseSortIndices; + +namespace { +constexpr int kTokenizerProcessUnitIndex = 0; +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateFromFile( + const std::string& path_to_model_with_metadata) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( + path_to_model_with_metadata, + 
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + RETURN_IF_ERROR(api_to_init->InitializeFromMetadata()); + return api_to_init; +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateFromBuffer( + const char* model_with_metadata_buffer_data, + size_t model_with_metadata_buffer_size) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>( + model_with_metadata_buffer_data, model_with_metadata_buffer_size, + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + RETURN_IF_ERROR(api_to_init->InitializeFromMetadata()); + return api_to_init; +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd( + int fd) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromFileDescriptor<BertQuestionAnswerer>( + fd, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + RETURN_IF_ERROR(api_to_init->InitializeFromMetadata()); + return api_to_init; +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateBertQuestionAnswererFromFile( + const std::string& path_to_model, const std::string& path_to_vocab) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( + path_to_model, + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + api_to_init->InitializeBertTokenizer(path_to_vocab); + return api_to_init; +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer( + const char* model_buffer_data, size_t model_buffer_size, + const char* vocab_buffer_data, size_t vocab_buffer_size) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + 
core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>( + model_buffer_data, model_buffer_size, + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + api_to_init->InitializeBertTokenizerFromBinary(vocab_buffer_data, + vocab_buffer_size); + return api_to_init; +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile( + const std::string& path_to_model, const std::string& path_to_spmodel) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>( + path_to_model, + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + api_to_init->InitializeSentencepieceTokenizer(path_to_spmodel); + return api_to_init; +} + +StatusOr<std::unique_ptr<QuestionAnswerer>> +BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer( + const char* model_buffer_data, size_t model_buffer_size, + const char* spmodel_buffer_data, size_t spmodel_buffer_size) { + std::unique_ptr<BertQuestionAnswerer> api_to_init; + ASSIGN_OR_RETURN( + api_to_init, + core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>( + model_buffer_data, model_buffer_size, + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), + kNumLiteThreads)); + api_to_init->InitializeSentencepieceTokenizerFromBinary(spmodel_buffer_data, + spmodel_buffer_size); + return api_to_init; +} + +std::vector<QaAnswer> BertQuestionAnswerer::Answer( + const std::string& context, const std::string& question) { + // The BertQuestionAnswererer implementation for Preprocess() and + // Postprocess() never returns errors: just call value(). 
+ return Infer(context, question).value(); +} + +absl::Status BertQuestionAnswerer::Preprocess( + const std::vector<TfLiteTensor*>& input_tensors, const std::string& context, + const std::string& query) { + auto* input_tensor_metadatas = + GetMetadataExtractor()->GetInputTensorMetadata(); + TfLiteTensor* ids_tensor = + input_tensor_metadatas + ? FindTensorByName(input_tensors, input_tensor_metadatas, + kIdsTensorName) + : input_tensors[0]; + TfLiteTensor* mask_tensor = + input_tensor_metadatas + ? FindTensorByName(input_tensors, input_tensor_metadatas, + kMaskTensorName) + : input_tensors[1]; + TfLiteTensor* segment_ids_tensor = + input_tensor_metadatas + ? FindTensorByName(input_tensors, input_tensor_metadatas, + kSegmentIdsTensorName) + : input_tensors[2]; + + token_to_orig_map_.clear(); + + // The orig_tokens is used for recovering the answer string from the index, + // while the processed_tokens is lower-cased and used to generate input of + // the model. + orig_tokens_ = absl::StrSplit(context, absl::ByChar(' '), absl::SkipEmpty()); + std::vector<std::string> processed_tokens(orig_tokens_); + + std::string processed_query = query; + if (kUseLowerCase) { + for (auto& token : processed_tokens) { + absl::AsciiStrToLower(&token); + } + absl::AsciiStrToLower(&processed_query); + } + + TokenizerResult query_tokenize_results; + query_tokenize_results = tokenizer_->Tokenize(processed_query); + + std::vector<std::string> query_tokens = query_tokenize_results.subwords; + if (query_tokens.size() > kMaxQueryLen) { + query_tokens.resize(kMaxQueryLen); + } + + // Example: + // context: tokenize me please + // all_doc_tokens: token ##ize me plea ##se + // token_to_orig_index: [0, 0, 1, 2, 2] + + std::vector<std::string> all_doc_tokens; + std::vector<int> token_to_orig_index; + for (size_t i = 0; i < processed_tokens.size(); i++) { + const std::string& token = processed_tokens[i]; + std::vector<std::string> sub_tokens = tokenizer_->Tokenize(token).subwords; + for (const 
std::string& sub_token : sub_tokens) { + token_to_orig_index.emplace_back(i); + all_doc_tokens.emplace_back(sub_token); + } + } + + // -3 accounts for [CLS], [SEP] and [SEP]. + int max_context_len = kMaxSeqLen - query_tokens.size() - 3; + if (all_doc_tokens.size() > max_context_len) { + all_doc_tokens.resize(max_context_len); + } + + std::vector<std::string> tokens; + tokens.reserve(3 + query_tokens.size() + all_doc_tokens.size()); + std::vector<int> segment_ids; + segment_ids.reserve(kMaxSeqLen); + + // Start of generating the features. + tokens.emplace_back("[CLS]"); + segment_ids.emplace_back(0); + + // For query input. + for (const auto& query_token : query_tokens) { + tokens.emplace_back(query_token); + segment_ids.emplace_back(0); + } + + // For Separation. + tokens.emplace_back("[SEP]"); + segment_ids.emplace_back(0); + + // For Text Input. + for (int i = 0; i < all_doc_tokens.size(); i++) { + auto& doc_token = all_doc_tokens[i]; + tokens.emplace_back(doc_token); + segment_ids.emplace_back(1); + token_to_orig_map_[tokens.size()] = token_to_orig_index[i]; + } + + // For ending mark. 
+ tokens.emplace_back("[SEP]"); + segment_ids.emplace_back(1); + + std::vector<int> input_ids(tokens.size()); + input_ids.reserve(kMaxSeqLen); + // Convert tokens back into ids + for (int i = 0; i < tokens.size(); i++) { + auto& token = tokens[i]; + tokenizer_->LookupId(token, &input_ids[i]); + } + + std::vector<int> input_mask; + input_mask.reserve(kMaxSeqLen); + input_mask.insert(input_mask.end(), tokens.size(), 1); + + int zeros_to_pad = kMaxSeqLen - input_ids.size(); + input_ids.insert(input_ids.end(), zeros_to_pad, 0); + input_mask.insert(input_mask.end(), zeros_to_pad, 0); + segment_ids.insert(segment_ids.end(), zeros_to_pad, 0); + + // input_ids INT32[1, 384] + PopulateTensor(input_ids, ids_tensor); + // input_mask INT32[1, 384] + PopulateTensor(input_mask, mask_tensor); + // segment_ids INT32[1, 384] + PopulateTensor(segment_ids, segment_ids_tensor); + + return absl::OkStatus(); +} + +StatusOr<std::vector<QaAnswer>> BertQuestionAnswerer::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& /*lowercased_context*/, + const std::string& /*lowercased_query*/) { + auto* output_tensor_metadatas = + GetMetadataExtractor()->GetOutputTensorMetadata(); + + const TfLiteTensor* end_logits_tensor = + output_tensor_metadatas + ? FindTensorByName(output_tensors, output_tensor_metadatas, + kEndLogitsTensorName) + : output_tensors[0]; + const TfLiteTensor* start_logits_tensor = + output_tensor_metadatas + ? 
FindTensorByName(output_tensors, output_tensor_metadatas, + kStartLogitsTensorName) + : output_tensors[1]; + + std::vector<float> end_logits; + std::vector<float> start_logits; + + // end_logits FLOAT[1, 384] + PopulateVector(end_logits_tensor, &end_logits); + // start_logits FLOAT[1, 384] + PopulateVector(start_logits_tensor, &start_logits); + + auto start_indices = ReverseSortIndices(start_logits); + auto end_indices = ReverseSortIndices(end_logits); + + std::vector<QaAnswer::Pos> orig_results; + for (int start_index = 0; start_index < kPredictAnsNum; start_index++) { + for (int end_index = 0; end_index < kPredictAnsNum; end_index++) { + int start = start_indices[start_index]; + int end = end_indices[end_index]; + + if (!token_to_orig_map_.contains(start + kOutputOffset) || + !token_to_orig_map_.contains(end + kOutputOffset) || end < start || + (end - start + 1) > kMaxAnsLen) { + continue; + } + orig_results.emplace_back( + QaAnswer::Pos(start, end, start_logits[start] + end_logits[end])); + } + } + + std::sort(orig_results.begin(), orig_results.end()); + + std::vector<QaAnswer> answers; + for (int i = 0; i < orig_results.size() && i < kPredictAnsNum; i++) { + auto orig_pos = orig_results[i]; + answers.emplace_back( + orig_pos.start > 0 ? 
ConvertIndexToString(orig_pos.start, orig_pos.end) + : "", + orig_pos); + } + + return answers; +} + +std::string BertQuestionAnswerer::ConvertIndexToString(int start, int end) { + int start_index = token_to_orig_map_[start + kOutputOffset]; + int end_index = token_to_orig_map_[end + kOutputOffset]; + + return absl::StrJoin(orig_tokens_.begin() + start_index, + orig_tokens_.begin() + end_index + 1, " "); +} + +absl::Status BertQuestionAnswerer::InitializeFromMetadata() { + const ProcessUnit* tokenizer_process_unit = + GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex); + if (tokenizer_process_unit == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "No input process unit found from metadata.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + ASSIGN_OR_RETURN(tokenizer_, + CreateTokenizerFromProcessUnit(tokenizer_process_unit, + GetMetadataExtractor())); + return absl::OkStatus(); +} + +void BertQuestionAnswerer::InitializeBertTokenizer( + const std::string& path_to_vocab) { + tokenizer_ = absl::make_unique<BertTokenizer>(path_to_vocab); +} + +void BertQuestionAnswerer::InitializeBertTokenizerFromBinary( + const char* vocab_buffer_data, size_t vocab_buffer_size) { + tokenizer_ = + absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size); +} + +void BertQuestionAnswerer::InitializeSentencepieceTokenizer( + const std::string& path_to_spmodel) { + tokenizer_ = absl::make_unique<SentencePieceTokenizer>(path_to_spmodel); +} + +void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary( + const char* spmodel_buffer_data, size_t spmodel_buffer_size) { + tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data, + spmodel_buffer_size); +} + +} // namespace qa +} // namespace text +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h 
b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h new file mode 100644 index 00000000..4c65dc00 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h @@ -0,0 +1,170 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" + +namespace tflite { +namespace task { +namespace text { +namespace qa { + +// BertQA task API, performs tokenization for models (BERT, Albert, etc.) in +// preprocess and returns most possible answers. +// +// In particular, the branch of BERT models use WordPiece tokenizer, and the +// branch of Albert models use SentencePiece tokenizer, respectively. 
+// +// Factory methods: +// CreateFromFile(path_to_model_with_metadata) +// CreateFromBuffer(model_with_metadata_buffer_data, +// model_with_metadata_buffer_size) +// CreateFromFd(file_descriptor_to_model_with_metadata) +// Generic API to create the QuestionAnswerer for bert models with metadata +// populated. The API expects a Bert based TFLite model with metadata +// containing the following information: +// - input_process_units for Wordpiece/Sentencepiece Tokenizer. Wordpiece +// Tokenizer can be used for a MobileBert[0] model, Sentencepiece +// Tokenizer Tokenizer can be used for an Albert[1] model +// - 3 input tensors with names "ids", "mask" and "segment_ids" +// - 2 output tensors with names "end_logits" and "start_logits" +// [0]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 +// [1]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 +// +// CreateBertQuestionAnswererFromFile(path_to_model, path_to_vocab) +// Creates a BertQuestionAnswerer from TFLite model file and vocab file for +// WordPiece tokenizer. Used in C++ environment. +// One suitable model is: +// https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 +// +// CreateBertQuestionAnswererFromBuffer(model_buffer_data, model_buffer_size, +// vocab_buffer_data, vocab_buffer_size) +// Creates a BertQuestionAnswerer from TFLite model buffer and vocab file +// buffer for WordPiece tokenizer. Used in Jave (JNI) environment. +// +// CreateAlbertQuestionAnswererFromFile(path_to_model, path_to_spmodel) +// Creates an AlbertQuestionAnswerer from TFLite model file and +// SentencePiece model file. Used in C++ environment. +// One suitable model is: +// https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 +// +// CreateAlbertQuestionAnswererFromBuffer(model_buffer_data, +// model_buffer_size, +// spmodel_buffer_data, +// spmodel_buffer_size) +// Creates an AlbertQuestionAnswerer from TFLite model file buffer and +// SentencePiece model file buffer. 
Used in Jave (JNI) environment. +// + +class BertQuestionAnswerer : public QuestionAnswerer { + public: + // TODO(b/150904655): add support to parameterize. + static constexpr int kMaxQueryLen = 64; + static constexpr int kMaxSeqLen = 384; + static constexpr int kPredictAnsNum = 5; + static constexpr int kMaxAnsLen = 32; + // TODO(b/151954803): clarify the offset usage + static constexpr int kOutputOffset = 1; + static constexpr int kNumLiteThreads = 4; + static constexpr bool kUseLowerCase = true; + + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateFromFile(const std::string& path_to_model_with_metadata); + + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateFromBuffer(const char* model_with_metadata_buffer_data, + size_t model_with_metadata_buffer_size); + + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateFromFd(int fd); + + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateBertQuestionAnswererFromFile(const std::string& path_to_model, + const std::string& path_to_vocab); + + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateBertQuestionAnswererFromBuffer(const char* model_buffer_data, + size_t model_buffer_size, + const char* vocab_buffer_data, + size_t vocab_buffer_size); + + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateAlbertQuestionAnswererFromFile(const std::string& path_to_model, + const std::string& path_to_spmodel); + + static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> + CreateAlbertQuestionAnswererFromBuffer(const char* model_buffer_data, + size_t model_buffer_size, + const char* spmodel_buffer_data, + size_t spmodel_buffer_size); + + explicit BertQuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine) + : QuestionAnswerer(std::move(engine)) {} + + // Answers question based on the context. Could be empty if no answer was + // found from the given context. 
+ std::vector<QaAnswer> Answer(const std::string& context, + const std::string& question) override; + + private: + absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const std::string& lowercased_context, + const std::string& lowercased_query) override; + + tflite::support::StatusOr<std::vector<QaAnswer>> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const std::string& lowercased_context, + const std::string& lowercased_query) override; + + // Initialize API with a BertTokenizer from the vocabulary file. + void InitializeBertTokenizer(const std::string& path_to_vocab); + // Initialize API with a BertTokenizer from the vocabulary buffer. + void InitializeBertTokenizerFromBinary(const char* vocab_buffer_data, + size_t vocab_buffer_size); + + // Initialize API with a SentencepieceTokenizer from the model file. + void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel); + // Initialize API with a SentencepieceTokenizer from the model buffer. + void InitializeSentencepieceTokenizerFromBinary( + const char* spmodel_buffer_data, size_t spmodel_buffer_size); + + // Initialize the API with the tokenizer set in the metadata. + absl::Status InitializeFromMetadata(); + + std::string ConvertIndexToString(int start, int end); + + std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_; + // Maps index of input token to index of untokenized word from original input. + absl::flat_hash_map<size_t, size_t> token_to_orig_map_; + // Original tokens of context. 
+ std::vector<std::string> orig_tokens_; +}; + +} // namespace qa +} // namespace text +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_ diff --git a/tensorflow_lite_support/cc/task/text/qa/question_answerer.h b/tensorflow_lite_support/cc/task/text/qa/question_answerer.h new file mode 100644 index 00000000..f46a40c2 --- /dev/null +++ b/tensorflow_lite_support/cc/task/text/qa/question_answerer.h @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_ + +#include <string> +#include <utility> +#include <vector> + +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +namespace tflite { +namespace task { +namespace text { +namespace qa { + +// Struct for the Answer to QuestionAnswerer. +struct QaAnswer { + // struct to represent the logit and offset of the answer related to context. 
+ struct Pos { + Pos(int arg_start, int arg_end, float arg_logit) + : start(arg_start), end(arg_end), logit(arg_logit) {} + int start, end; + float logit; + bool operator<(const Pos& rhs) const { return rhs.logit < logit; } + }; + + QaAnswer(std::string arg_text, Pos arg_pos) + : text(std::move(arg_text)), pos(arg_pos) {} + std::string text; + Pos pos; +}; + +// Interface for an Question-Answer API. +class QuestionAnswerer + : public core::BaseTaskApi<std::vector<QaAnswer>, const std::string&, + const std::string&> { + public: + explicit QuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine) + : BaseTaskApi(std::move(engine)) {} + + virtual std::vector<QaAnswer> Answer(const std::string& context, + const std::string& question) = 0; +}; + +} // namespace qa +} // namespace text +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_QUESTION_ANSWERER_H_ diff --git a/tensorflow_lite_support/cc/task/vision/BUILD b/tensorflow_lite_support/cc/task/vision/BUILD new file mode 100644 index 00000000..d426486f --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/BUILD @@ -0,0 +1,108 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "object_detector", + srcs = ["object_detector.cc"], + hdrs = ["object_detector.h"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/core:label_map_item", + 
"//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api", + ], +) + +cc_library( + name = "image_classifier", + srcs = ["image_classifier.cc"], + hdrs = ["image_classifier.h"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", + "//tensorflow_lite_support/cc/task/vision/core:classification_head", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/core:label_map_item", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", + 
"//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/task/vision/utils:score_calibration", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api", + ], +) + +cc_library( + name = "image_segmenter", + srcs = ["image_segmenter.cc"], + hdrs = ["image_segmenter.h"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core:task_api_factory", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision/core:base_vision_task_api", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/core:label_map_item", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + 
"@flatbuffers", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/core/api", + ], +) diff --git a/tensorflow_lite_support/cc/task/vision/core/BUILD b/tensorflow_lite_support/cc/task/vision/core/BUILD new file mode 100644 index 00000000..1df86cb9 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/core/BUILD @@ -0,0 +1,81 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(srcs = ["base_vision_task_api.h"]) + +cc_library( + name = "base_vision_task_api", + hdrs = [ + "base_vision_task_api.h", + ], + deps = [ + ":frame_buffer", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/task/core:base_task_api", + "//tensorflow_lite_support/cc/task/core:task_utils", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils", + "//tensorflow_lite_support/cc/task/vision/utils:image_tensor_specs", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + "@org_tensorflow//tensorflow/lite/c:common", + ], +) + +cc_library( + name = "frame_buffer", + srcs = ["frame_buffer.cc"], + hdrs = ["frame_buffer.h"], + deps = [ + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:any", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "label_map_item", + srcs = ["label_map_item.cc"], + hdrs = ["label_map_item.h"], + deps = [ + "//tensorflow_lite_support/cc:common", + 
"//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "classification_head", + srcs = ["classification_head.cc"], + hdrs = ["classification_head.h"], + deps = [ + ":label_map_item", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision/utils:score_calibration", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h b/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h new file mode 100644 index 00000000..feb6b4a1 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h @@ -0,0 +1,270 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_ + +#include <array> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/time/clock.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/base_task_api.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace vision { + +// Base class providing common logic for vision models. +template <class OutputType> +class BaseVisionTaskApi + : public tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&, + const BoundingBox&> { + public: + explicit BaseVisionTaskApi(std::unique_ptr<core::TfLiteEngine> engine) + : tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&, + const BoundingBox&>(std::move(engine)) { + } + // BaseVisionTaskApi is neither copyable nor movable. + BaseVisionTaskApi(const BaseVisionTaskApi&) = delete; + BaseVisionTaskApi& operator=(const BaseVisionTaskApi&) = delete; + + // Number of bytes required for 8-bit per pixel RGB color space. 
+ static constexpr int kRgbPixelBytes = 3; + + // Sets the ProcessEngine used for image pre-processing. Must be called before + // any inference is performed. Can be called between inferences to override + // the current process engine. + void SetProcessEngine(const FrameBufferUtils::ProcessEngine& process_engine) { + frame_buffer_utils_ = FrameBufferUtils::Create(process_engine); + } + + protected: + using tflite::task::core::BaseTaskApi<OutputType, const FrameBuffer&, + const BoundingBox&>::engine_; + + // Checks input tensor and metadata (if any) are valid, or return an error + // otherwise. This must be called once at initialization time, before running + // inference, as it is a prerequisite for `Preprocess`. + // Note: the underlying interpreter and metadata extractor are assumed to be + // already successfully initialized before calling this method. + virtual absl::Status CheckAndSetInputs() { + ASSIGN_OR_RETURN( + ImageTensorSpecs input_specs, + BuildInputImageTensorSpecs(*engine_->interpreter(), + *engine_->metadata_extractor())); + + if (input_specs.color_space != tflite::ColorSpaceType_RGB) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kUnimplemented, + "BaseVisionTaskApi only supports RGB color space for now."); + } + + input_specs_ = absl::make_unique<ImageTensorSpecs>(input_specs); + + return absl::OkStatus(); + } + + // Performs image preprocessing on the input frame buffer over the region of + // interest so that it fits model requirements (e.g. upright 224x224 RGB) and + // populate the corresponding input tensor. This is performed by (in this + // order): + // - cropping the frame buffer to the region of interest (which, in most + // cases, just covers the entire input image), + // - resizing it (with bilinear interpolation, aspect-ratio *not* preserved) + // to the dimensions of the model input tensor, + // - converting it to the colorspace of the input tensor (i.e. 
RGB, which is + // the only supported colorspace for now), + // - rotating it according to its `Orientation` so that inference is performed + // on an "upright" image. + // + // IMPORTANT: as a consequence of cropping occurring first, the provided + // region of interest is expressed in the unrotated frame of reference + // coordinates system, i.e. in `[0, frame_buffer.width) x [0, + // frame_buffer.height)`, which are the dimensions of the underlying + // `frame_buffer` data before any `Orientation` flag gets applied. Also, the + // region of interest is not clamped, so this method will return a non-ok + // status if the region is out of these bounds. + absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors, + const FrameBuffer& frame_buffer, + const BoundingBox& roi) override { + if (input_specs_ == nullptr) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Uninitialized input tensor specs: CheckAndSetInputs must be called " + "at initialization time."); + } + + if (frame_buffer_utils_ == nullptr) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Uninitialized frame buffer utils: SetProcessEngine must be called " + "at initialization time."); + } + + if (input_tensors.size() != 1) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, "A single input tensor is expected."); + } + + // Input data to be normalized (if needed) and used for inference. In most + // cases, this is the result of image preprocessing. In case no image + // preprocessing is needed (see below), this points to the input frame + // buffer raw data. + const uint8* input_data; + size_t input_data_byte_size; + + // Optional buffers in case image preprocessing is needed. + std::unique_ptr<FrameBuffer> preprocessed_frame_buffer; + std::vector<uint8> preprocessed_data; + + if (IsImagePreprocessingNeeded(frame_buffer, roi)) { + // Preprocess input image to fit model requirements. 
+ // For now RGB is the only color space supported, which is ensured by + // `CheckAndSetInputs`. + FrameBuffer::Dimension to_buffer_dimension = {input_specs_->image_width, + input_specs_->image_height}; + input_data_byte_size = + GetBufferByteSize(to_buffer_dimension, FrameBuffer::Format::kRGB); + preprocessed_data.resize(input_data_byte_size / sizeof(uint8), 0); + input_data = preprocessed_data.data(); + + FrameBuffer::Plane preprocessed_plane = { + /*buffer=*/preprocessed_data.data(), + /*stride=*/{input_specs_->image_width * kRgbPixelBytes, + kRgbPixelBytes}}; + preprocessed_frame_buffer = FrameBuffer::Create( + {preprocessed_plane}, to_buffer_dimension, FrameBuffer::Format::kRGB, + FrameBuffer::Orientation::kTopLeft); + + RETURN_IF_ERROR(frame_buffer_utils_->Preprocess( + frame_buffer, roi, preprocessed_frame_buffer.get())); + } else { + // Input frame buffer already targets model requirements: skip image + // preprocessing. For RGB, the data is always stored in a single plane. + input_data = frame_buffer.plane(0).buffer; + input_data_byte_size = frame_buffer.plane(0).stride.row_stride_bytes * + frame_buffer.dimension().height; + } + + // Then normalize pixel data (if needed) and populate the input tensor. + switch (input_specs_->tensor_type) { + case kTfLiteUInt8: + if (input_tensors[0]->bytes != input_data_byte_size) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Size mismatch or unsupported padding bytes between pixel data " + "and input tensor."); + } + // No normalization required: directly populate data. 
+ tflite::task::core::PopulateTensor( + input_data, input_data_byte_size / sizeof(uint8), input_tensors[0]); + break; + case kTfLiteFloat32: { + if (input_tensors[0]->bytes / sizeof(float) != + input_data_byte_size / sizeof(uint8)) { + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, + "Size mismatch or unsupported padding bytes between pixel data " + "and input tensor."); + } + // Normalize and populate. + float* normalized_input_data = + tflite::task::core::AssertAndReturnTypedTensor<float>( + input_tensors[0]); + const tflite::task::vision::NormalizationOptions& + normalization_options = input_specs_->normalization_options.value(); + if (normalization_options.num_values == 1) { + float mean_value = normalization_options.mean_values[0]; + float inv_std_value = (1.0f / normalization_options.std_values[0]); + for (int i = 0; i < input_data_byte_size / sizeof(uint8); + i++, input_data++, normalized_input_data++) { + *normalized_input_data = + inv_std_value * (static_cast<float>(*input_data) - mean_value); + } + } else { + std::array<float, 3> inv_std_values = { + 1.0f / normalization_options.std_values[0], + 1.0f / normalization_options.std_values[1], + 1.0f / normalization_options.std_values[2]}; + for (int i = 0; i < input_data_byte_size / sizeof(uint8); + i++, input_data++, normalized_input_data++) { + *normalized_input_data = inv_std_values[i % 3] * + (static_cast<float>(*input_data) - + normalization_options.mean_values[i % 3]); + } + } + break; + } + case kTfLiteInt8: + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kUnimplemented, + "kTfLiteInt8 input type is not implemented yet."); + default: + return tflite::support::CreateStatusWithPayload( + absl::StatusCode::kInternal, "Unexpected input tensor type."); + } + + return absl::OkStatus(); + } + + // Utils for input image preprocessing (resizing, colorspace conversion, etc). 
+ std::unique_ptr<FrameBufferUtils> frame_buffer_utils_; + + // Parameters related to the input tensor which represents an image. + std::unique_ptr<ImageTensorSpecs> input_specs_; + + private: + // Returns false if image preprocessing could be skipped, true otherwise. + bool IsImagePreprocessingNeeded(const FrameBuffer& frame_buffer, + const BoundingBox& roi) { + // Is crop required? + if (roi.origin_x() != 0 || roi.origin_y() != 0 || + roi.width() != frame_buffer.dimension().width || + roi.height() != frame_buffer.dimension().height) { + return true; + } + + // Are image transformations required? + if (frame_buffer.orientation() != FrameBuffer::Orientation::kTopLeft || + frame_buffer.format() != FrameBuffer::Format::kRGB || + frame_buffer.dimension().width != input_specs_->image_width || + frame_buffer.dimension().height != input_specs_->image_height) { + return true; + } + + return false; + } +}; + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_BASE_VISION_TASK_API_H_ diff --git a/tensorflow_lite_support/cc/task/vision/core/classification_head.cc b/tensorflow_lite_support/cc/task/vision/core/classification_head.cc new file mode 100644 index 00000000..962cb34b --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/core/classification_head.cc @@ -0,0 +1,114 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow_lite_support/cc/task/vision/core/classification_head.h" + +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace vision { + +using ::absl::StatusCode; +using ::tflite::metadata::ModelMetadataExtractor; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +StatusOr<ClassificationHead> BuildClassificationHead( + const tflite::metadata::ModelMetadataExtractor& metadata_extractor, + const tflite::TensorMetadata& output_tensor_metadata, + absl::string_view display_names_locale) { + ClassificationHead head; + if (output_tensor_metadata.name() != nullptr) { + head.name = output_tensor_metadata.name()->str(); + } + + // Build label map, if present. + const std::string labels_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + output_tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_LABELS); + if (!labels_filename.empty()) { + ASSIGN_OR_RETURN(absl::string_view labels_file, + metadata_extractor.GetAssociatedFile(labels_filename)); + const std::string display_names_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + output_tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_LABELS, + display_names_locale); + absl::string_view display_names_file; + if (!display_names_filename.empty()) { + ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile( + display_names_filename)); + } + ASSIGN_OR_RETURN(head.label_map_items, + BuildLabelMapFromFiles(labels_file, display_names_file)); + } + + // Set score threshold, if present. 
+ ASSIGN_OR_RETURN(const tflite::ProcessUnit* score_thresholding_process_unit, + ModelMetadataExtractor::FindFirstProcessUnit( + output_tensor_metadata, + tflite::ProcessUnitOptions_ScoreThresholdingOptions)); + if (score_thresholding_process_unit != nullptr) { + head.score_threshold = + score_thresholding_process_unit->options_as_ScoreThresholdingOptions() + ->global_score_threshold(); + } + + // Build score calibration parameters, if present. + ASSIGN_OR_RETURN(const tflite::ProcessUnit* score_calibration_process_unit, + ModelMetadataExtractor::FindFirstProcessUnit( + output_tensor_metadata, + tflite::ProcessUnitOptions_ScoreCalibrationOptions)); + if (score_calibration_process_unit != nullptr) { + if (labels_filename.empty()) { + return CreateStatusWithPayload( + StatusCode::kNotFound, + "Using ScoreCalibrationOptions requires a label map to be provided " + "as TENSOR_AXIS_LABELS associated file.", + TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError); + } + const std::string score_calibration_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + output_tensor_metadata, + tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION); + if (score_calibration_filename.empty()) { + return CreateStatusWithPayload( + StatusCode::kNotFound, + "Found ScoreCalibrationOptions but missing required associated " + "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.", + TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError); + } + ASSIGN_OR_RETURN( + absl::string_view score_calibration_file, + metadata_extractor.GetAssociatedFile(score_calibration_filename)); + ASSIGN_OR_RETURN(SigmoidCalibrationParameters sigmoid_params, + BuildSigmoidCalibrationParams( + *score_calibration_process_unit + ->options_as_ScoreCalibrationOptions(), + score_calibration_file, head.label_map_items)); + head.calibration_params = sigmoid_params; + } + + return head; +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git 
a/tensorflow_lite_support/cc/task/vision/core/classification_head.h b/tensorflow_lite_support/cc/task/vision/core/classification_head.h new file mode 100644 index 00000000..07cd8b9b --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/core/classification_head.h @@ -0,0 +1,110 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_ + +#include <string> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" +#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace vision { + +// A single classifier head for an image classifier model, associated with a +// corresponding output tensor. 
+struct ClassificationHead { + ClassificationHead() : score_threshold(0) {} + + explicit ClassificationHead( + const std::vector<tflite::task::vision::LabelMapItem>&& label_map_items) + : label_map_items(label_map_items), score_threshold(0) {} + + // An optional name that usually indicates what this set of classes represent, + // e.g. "flowers". + std::string name; + // The label map representing the list of supported classes, aka labels. + // + // This must be in direct correspondence with the associated output tensor, + // i.e.: + // + // - The number of classes must match with the dimension of the corresponding + // output tensor, + // - The i-th item in the label map is assumed to correspond to the i-th + // output value in the output tensor. + // + // This requires to put in place dedicated sanity checks before running + // inference. + std::vector<tflite::task::vision::LabelMapItem> label_map_items; + // Recommended score threshold typically in [0,1[. Classification results with + // a score below this value are considered low-confidence and should be + // rejected from returned results. + float score_threshold; + // Optional score calibration parameters (one set of parameters per class in + // the label map). This is primarily meant for multi-label classifiers made of + // independent sigmoids. + // + // Such parameters are usually tuned so that calibrated scores can be compared + // to a default threshold common to all classes to achieve a given amount of + // precision. + // + // Example: 60% precision for threshold = 0.5. + absl::optional<tflite::task::vision::SigmoidCalibrationParameters> + calibration_params; +}; + +// Builds a classification head using the provided metadata extractor, for the +// given output tensor metadata. Returns an error in case the head cannot be +// built (e.g. missing associated file for score calibration parameters). +// +// Optionally it is possible to specify which locale should be used (e.g. 
"en") +// to fill the label map display names, if any, and provided the corresponding +// associated file is present in the metadata. If no locale is specified, or if +// there is no associated file for the provided locale, display names are just +// left empty and no error is returned. +// +// E.g. (metadata displayed in JSON format below): +// +// ... +// "associated_files": [ +// { +// "name": "labels.txt", +// "type": "TENSOR_AXIS_LABELS" +// }, +// { +// "name": "labels-en.txt", +// "type": "TENSOR_AXIS_LABELS", +// "locale": "en" +// }, +// ... +// +// See metadata schema TENSOR_AXIS_LABELS for more details. +tflite::support::StatusOr<ClassificationHead> BuildClassificationHead( + const tflite::metadata::ModelMetadataExtractor& metadata_extractor, + const tflite::TensorMetadata& output_tensor_metadata, + absl::string_view display_names_locale = absl::string_view()); + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_CLASSIFICATION_HEAD_ITEM_H_ diff --git a/tensorflow_lite_support/cc/task/vision/core/frame_buffer.cc b/tensorflow_lite_support/cc/task/vision/core/frame_buffer.cc new file mode 100644 index 00000000..02658cd9 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/core/frame_buffer.cc @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" + +namespace tflite { +namespace task { +namespace vision { + +using ::tflite::support::StatusOr; + +namespace { + +// Returns whether the input `format` is a supported YUV format. +bool IsSupportedYuvFormat(FrameBuffer::Format format) { + return format == FrameBuffer::Format::kNV21 || + format == FrameBuffer::Format::kNV12 || + format == FrameBuffer::Format::kYV12 || + format == FrameBuffer::Format::kYV21; +} + +// Returns supported 1-plane FrameBuffer in YuvData structure. +StatusOr<FrameBuffer::YuvData> GetYuvDataFromOnePlaneFrameBuffer( + const FrameBuffer& source) { + if (!IsSupportedYuvFormat(source.format())) { + return absl::InvalidArgumentError( + "The source FrameBuffer format is not part of YUV420 family."); + } + + FrameBuffer::YuvData result; + const int y_buffer_size = + source.plane(0).stride.row_stride_bytes * source.dimension().height; + const int uv_buffer_size = + ((source.plane(0).stride.row_stride_bytes + 1) / 2) * + ((source.dimension().height + 1) / 2); + result.y_buffer = source.plane(0).buffer; + result.y_row_stride = source.plane(0).stride.row_stride_bytes; + result.uv_row_stride = result.y_row_stride; + + if (source.format() == FrameBuffer::Format::kNV21) { + result.v_buffer = result.y_buffer + y_buffer_size; + result.u_buffer = result.v_buffer + 1; + result.uv_pixel_stride = 2; + // If y_row_stride equals to the frame width and is an odd value, + // uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride. 
+ if (result.y_row_stride == source.dimension().width && + result.y_row_stride % 2 == 1) { + result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2; + } + } else if (source.format() == FrameBuffer::Format::kNV12) { + result.u_buffer = result.y_buffer + y_buffer_size; + result.v_buffer = result.u_buffer + 1; + result.uv_pixel_stride = 2; + // If y_row_stride equals to the frame width and is an odd value, + // uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride. + if (result.y_row_stride == source.dimension().width && + result.y_row_stride % 2 == 1) { + result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2; + } + } else if (source.format() == FrameBuffer::Format::kYV21) { + result.u_buffer = result.y_buffer + y_buffer_size; + result.v_buffer = result.u_buffer + uv_buffer_size; + result.uv_pixel_stride = 1; + result.uv_row_stride = (result.y_row_stride + 1) / 2; + } else if (source.format() == FrameBuffer::Format::kYV12) { + result.v_buffer = result.y_buffer + y_buffer_size; + result.u_buffer = result.v_buffer + uv_buffer_size; + result.uv_pixel_stride = 1; + result.uv_row_stride = (result.y_row_stride + 1) / 2; + } + return result; +} + +// Returns supported 2-plane FrameBuffer in YuvData structure. 
+StatusOr<FrameBuffer::YuvData> GetYuvDataFromTwoPlaneFrameBuffer( + const FrameBuffer& source) { + if (source.format() != FrameBuffer::Format::kNV12 && + source.format() != FrameBuffer::Format::kNV21) { + return absl::InvalidArgumentError("Unsupported YUV planar format."); + } + + FrameBuffer::YuvData result; + // Y plane + result.y_buffer = source.plane(0).buffer; + // All plane strides + result.y_row_stride = source.plane(0).stride.row_stride_bytes; + result.uv_row_stride = source.plane(1).stride.row_stride_bytes; + result.uv_pixel_stride = 2; + + if (source.format() == FrameBuffer::Format::kNV12) { + // Y and UV interleaved format + result.u_buffer = source.plane(1).buffer; + result.v_buffer = result.u_buffer + 1; + } else { + // Y and VU interleaved format + result.v_buffer = source.plane(1).buffer; + result.u_buffer = result.v_buffer + 1; + } + return result; +} + +// Returns supported 3-plane FrameBuffer in YuvData structure. Note that NV21 +// and NV12 are included in the supported Yuv formats. Technically, NV21 and +// NV12 should not be described by the 3-plane format. Historically, NV21 is +// used loosely such that it can also be used to describe YV21 format. For +// backwards compatibility, FrameBuffer supports NV21/NV12 with 3-plane format +// but such usage is discouraged +StatusOr<FrameBuffer::YuvData> GetYuvDataFromThreePlaneFrameBuffer( + const FrameBuffer& source) { + if (!IsSupportedYuvFormat(source.format())) { + return absl::InvalidArgumentError( + "The source FrameBuffer format is not part of YUV420 family."); + } + + if (source.plane(1).stride.row_stride_bytes != + source.plane(2).stride.row_stride_bytes || + source.plane(1).stride.pixel_stride_bytes != + source.plane(2).stride.pixel_stride_bytes) { + return absl::InternalError("Unsupported YUV planar format."); + } + FrameBuffer::YuvData result; + if (source.format() == FrameBuffer::Format::kNV21 || + source.format() == FrameBuffer::Format::kYV12) { + // Y follow by VU order. 
The VU chroma planes can be interleaved or + // planar. + result.y_buffer = source.plane(0).buffer; + result.v_buffer = source.plane(1).buffer; + result.u_buffer = source.plane(2).buffer; + result.y_row_stride = source.plane(0).stride.row_stride_bytes; + result.uv_row_stride = source.plane(1).stride.row_stride_bytes; + result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes; + } else { + // Y follow by UV order. The UV chroma planes can be interleaved or + // planar. + result.y_buffer = source.plane(0).buffer; + result.u_buffer = source.plane(1).buffer; + result.v_buffer = source.plane(2).buffer; + result.y_row_stride = source.plane(0).stride.row_stride_bytes; + result.uv_row_stride = source.plane(1).stride.row_stride_bytes; + result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes; + } + return result; +} + +} // namespace + +StatusOr<FrameBuffer::YuvData> FrameBuffer::GetYuvDataFromFrameBuffer( + const FrameBuffer& source) { + if (!IsSupportedYuvFormat(source.format())) { + return absl::InvalidArgumentError( + "The source FrameBuffer format is not part of YUV420 family."); + } + + if (source.plane_count() == 1) { + return GetYuvDataFromOnePlaneFrameBuffer(source); + } else if (source.plane_count() == 2) { + return GetYuvDataFromTwoPlaneFrameBuffer(source); + } else if (source.plane_count() == 3) { + return GetYuvDataFromThreePlaneFrameBuffer(source); + } + return absl::InvalidArgumentError( + "The source FrameBuffer must be consisted by 1, 2, or 3 planes"); +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h b/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h new file mode 100644 index 00000000..31589f38 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/core/frame_buffer.h @@ -0,0 +1,296 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_FRAME_BUFFER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_FRAME_BUFFER_H_ + +#include <map> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/any.h" +#include "absl/types/optional.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/statusor.h" + +namespace tflite { +namespace task { +namespace vision { + +// A `FrameBuffer` provides a view into the provided backing buffer (e.g. camera +// frame or still image) with buffer format information. FrameBuffer doesn't +// take ownership of the provided backing buffer. The caller is responsible to +// manage the backing buffer lifecycle for the lifetime of the FrameBuffer. +// +// FrameBuffer also provides a tagging system to allow the client of FrameBuffer +// to attach arbitrary tags to an instance. The tagging system is meant for +// small set of metadata. FrameBuffer does not use the tags in anyway. The +// uniqueness of the tag is only guarded by the uniqueness of the key. +// The tag is useful when the uniqueness of a FrameBuffer can not be determined +// by its associated metadata. 
For example, there are two FrameBuffer instances +// with the same metadata (size dimension, orientation, format, etc.) but one is +// generated through cropping of Frame A and another is generated by resizing of +// Frame A. The client can tag one of the generated FrameBuffer to distinguish +// the difference. +// +// Examples: +// +// // Create a metadata instance with no backing buffer. +// auto buffer = FrameBuffer::Create(/*planes=*/{}, dimension, kRGBA, +// kTopLeft); +// +// // Create an RGBA instance with backing buffer on single plane. +// FrameBuffer::Plane plane = {rgba_buffer, /*stride=*/{dimension.width * 4, +// 4}}; auto buffer = FrameBuffer::Create({plane}, dimension, kRGBA, kTopLeft); +// +// // Create a YUV instance with planar backing buffer. +// FrameBuffer::Plane y_plane = {y_buffer, /*stride=*/{dimension.width , 1}}; +// FrameBuffer::Plane uv_plane = {u_buffer, /*stride=*/{dimension.width, 2}}; +// auto buffer = FrameBuffer::Create({y_plane, uv_plane}, dimension, kNV21, +// kLeftTop); +// +// // Add / retrieve tags from a FrameBuffer instance. +// buffer.InsertTag("my_special_key", 1); +// buffer.GetTag("my_special_key"); +// +class FrameBuffer { + public: + // Colorspace formats. + enum class Format { kRGBA, kRGB, kNV12, kNV21, kYV12, kYV21, kGRAY }; + + // Stride information. + struct Stride { + // The row stride in bytes. This is the distance between the start pixels of + // two consecutive rows in the image. + int row_stride_bytes; + // This is the distance between two consecutive pixel values in a row of + // pixels in bytes. It may be larger than the size of a single pixel to + // account for interleaved image data or padded formats. + int pixel_stride_bytes; + }; + + // YUV data structure. + struct YuvData { + const uint8* y_buffer; + const uint8* u_buffer; + const uint8* v_buffer; + // Y buffer row stride in bytes. + int y_row_stride; + // U/V buffer row stride in bytes. + int uv_row_stride; + // U/V pixel stride in bytes. 
This is the distance between two consecutive + // u/v pixel values in a row. + int uv_pixel_stride; + }; + + // FrameBuffer content orientation follows EXIF specification. The name of + // each enum value defines the position of the 0th row and the 0th column of + // the image content. See http://jpegclub.org/exif_orientation.html for + // details. + enum class Orientation { + kTopLeft = 1, + kTopRight = 2, + kBottomRight = 3, + kBottomLeft = 4, + kLeftTop = 5, + kRightTop = 6, + kRightBottom = 7, + kLeftBottom = 8 + }; + + // Plane encapsulates buffer and stride information. + struct Plane { + const uint8* buffer; + Stride stride; + }; + + // Dimension information for the whole frame or a cropped portion of it. + struct Dimension { + // The width dimension in pixel unit. + int width; + // The height dimension in pixel unit. + int height; + + bool operator==(const Dimension& other) const { + return width == other.width && height == other.height; + } + + bool operator!=(const Dimension& other) const { + return width != other.width || height != other.height; + } + + bool operator>=(const Dimension& other) const { + return width >= other.width && height >= other.height; + } + + bool operator<=(const Dimension& other) const { + return width <= other.width && height <= other.height; + } + + // Swaps width and height. + void Swap() { + using std::swap; + swap(width, height); + } + + // Returns area represented by width * height. + int Size() const { return width * height; } + }; + + // Factory method for creating a FrameBuffer object from row-major backing + // buffers. In a streaming use case (e.g continuous camera stream), the + // timestamp can be used as an ID to identify a frame. 
+ static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes, + Dimension dimension, Format format, + Orientation orientation, + absl::Time timestamp) { + return absl::make_unique<FrameBuffer>(planes, dimension, format, + orientation, timestamp); + } + + // Factory method for creating a FrameBuffer object from row-major movable + // backing buffers. In a streaming use case (e.g continuous camera stream), + // the timestamp can be used as an ID to identify a frame. + static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes, + Dimension dimension, Format format, + Orientation orientation, + absl::Time timestamp) { + return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format, + orientation, timestamp); + } + + // Factory method for creating a FrameBuffer object from row-major backing + // buffers. By default this method set the timestamp to now. This method is + // more suitable for processing use case that does not need to re-identify + // this buffer. + static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes, + Dimension dimension, Format format, + Orientation orientation) { + return absl::make_unique<FrameBuffer>(planes, dimension, format, + orientation, absl::Now()); + } + + // Factory method for creating a FrameBuffer object from movable row-major + // backing buffers. By default this method set the timestamp to now. This + // method is more suitable for processing use case that does not need to + // re-identify this buffer. + static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes, + Dimension dimension, Format format, + Orientation orientation) { + return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format, + orientation, absl::Now()); + } + + // Returns YuvData which contains the Y, U, and V buffer and their + // stride info from the input `source` FrameBuffer which is in the YUV family + // formats (e.g NV12, NV21, YV12, and YV21). 
+ static tflite::support::StatusOr<YuvData> GetYuvDataFromFrameBuffer( + const FrameBuffer& source); + + // Builds a FrameBuffer object from a row-major backing buffer. + // + // The FrameBuffer does not take ownership of the backing buffer. The backing + // buffer is read-only and the caller is responsible for maintaining the + // backing buffer lifecycle for the lifetime of FrameBuffer. + FrameBuffer(const std::vector<Plane>& planes, Dimension dimension, + Format format, Orientation orientation, absl::Time timestamp) + : planes_(planes), + dimension_(dimension), + format_(format), + orientation_(orientation), + timestamp_(timestamp) {} + + // Builds a FrameBuffer object from a movable row-major backing buffer. + // + // The FrameBuffer does not take ownership of the backing buffer. The backing + // buffer is read-only and the caller is responsible for maintaining the + // backing buffer lifecycle for the lifetime of FrameBuffer. + FrameBuffer(std::vector<Plane>&& planes, Dimension dimension, Format format, + Orientation orientation, absl::Time timestamp) + : planes_(std::move(planes)), + dimension_(dimension), + format_(format), + orientation_(orientation), + timestamp_(timestamp) {} + + // Returns number of planes. + const int plane_count() const { return planes_.size(); } + + // Returns plane indexed by the input `index`. + const Plane plane(int index) const { + if (index > -1 && index < planes_.size()) { + return planes_[index]; + } + return {}; + } + + // Returns the tag associated to the tag_key. + absl::any GetTag(const std::string& tag_key) const { + auto iter = tags_.find(tag_key); + if (iter != tags_.end()) { + return iter->second; + } + return absl::any(); + } + + // Inserts or updates the tags map with key value pair (tag_key, tag_value). + void InsertOrUpdateTag(const std::string& tag_key, absl::any tag_value) { + tags_[tag_key] = std::move(tag_value); + } + + // Inserts the key value pair (tag_key, tag_value) into tags map. 
If the + // tag_key already exists, an internal error will return. + absl::Status InsertTag(const std::string& tag_key, absl::any tag_value) { + auto iter = tags_.emplace(tag_key, tag_value); + if (iter.second) { + return absl::OkStatus(); + } + return absl::InternalError(absl::StrCat( + "tag_key already exists in tags.tag_key was not inserted: ", tag_key)); + } + + // Returns FrameBuffer dimension. + const Dimension dimension() const { return dimension_; } + + // Returns FrameBuffer format. + const Format format() const { return format_; } + + // Returns FrameBuffer orientation. + const Orientation orientation() const { return orientation_; } + + // Returns FrameBuffer timestamp. + const absl::Time timestamp() const { return timestamp_; } + + private: + std::vector<Plane> planes_; + std::map<std::string, absl::any> tags_; + Dimension dimension_; + Format format_; + Orientation orientation_; + absl::Time timestamp_; +}; + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_FRAME_BUFFER_H_ diff --git a/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc b/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc new file mode 100644 index 00000000..75b1fc60 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc @@ -0,0 +1,128 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "tensorflow_lite_support/cc/common.h" + +namespace tflite { +namespace task { +namespace vision { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles( + absl::string_view labels_file, absl::string_view display_names_file) { + if (labels_file.empty()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Expected non-empty labels file.", + TfLiteSupportStatus::kInvalidArgumentError); + } + std::vector<absl::string_view> labels = absl::StrSplit(labels_file, '\n'); + // In most cases, there is an empty line (i.e. newline character) at the end + // of the file that needs to be ignored. In such a situation, StrSplit() will + // produce a vector with an empty string as final element. Also note that in + // case `labels_file` is entirely empty, StrSplit() will produce a vector with + // one single empty substring, so there's no out-of-range risk here. + if (labels[labels.size() - 1].empty()) { + labels.pop_back(); + } + + std::vector<LabelMapItem> label_map_items; + label_map_items.reserve(labels.size()); + for (int i = 0; i < labels.size(); ++i) { + label_map_items.emplace_back(LabelMapItem{.name = std::string(labels[i])}); + } + + if (!display_names_file.empty()) { + std::vector<std::string> display_names = + absl::StrSplit(display_names_file, '\n'); + // In most cases, there is an empty line (i.e. newline character) at the end + // of the file that needs to be ignored. See above. 
+ if (display_names[display_names.size() - 1].empty()) { + display_names.pop_back(); + } + if (display_names.size() != labels.size()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Mismatch between number of labels (%d) and display names (%d).", + labels.size(), display_names.size()), + TfLiteSupportStatus::kMetadataNumLabelsMismatchError); + } + for (int i = 0; i < display_names.size(); ++i) { + label_map_items[i].display_name = display_names[i]; + } + } + return label_map_items; +} + +absl::Status LabelHierarchy::InitializeFromLabelMap( + std::vector<LabelMapItem> label_map_items) { + parents_map_.clear(); + for (const LabelMapItem& label : label_map_items) { + for (const std::string& child_name : label.child_name) { + parents_map_[child_name].insert(label.name); + } + } + if (parents_map_.empty()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Input labelmap is not hierarchical: there " + "is no parent-child relationship.", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +bool LabelHierarchy::HaveAncestorDescendantRelationship( + const std::string& ancestor_name, + const std::string& descendant_name) const { + absl::flat_hash_set<std::string> ancestors; + GetAncestors(descendant_name, &ancestors); + return ancestors.contains(ancestor_name); +} + +absl::flat_hash_set<std::string> LabelHierarchy::GetParents( + const std::string& name) const { + absl::flat_hash_set<std::string> parents; + auto it = parents_map_.find(name); + if (it != parents_map_.end()) { + for (const std::string& parent_name : it->second) { + parents.insert(parent_name); + } + } + return parents; +} + +void LabelHierarchy::GetAncestors( + const std::string& name, + absl::flat_hash_set<std::string>* ancestors) const { + const absl::flat_hash_set<std::string> parents = GetParents(name); + for (const std::string& parent_name : parents) { + auto it = ancestors->insert(parent_name); + if (it.second) { + 
GetAncestors(parent_name, ancestors); + } + } +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/vision/core/label_map_item.h b/tensorflow_lite_support/cc/task/vision/core/label_map_item.h new file mode 100644 index 00000000..3ac9a000 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/core/label_map_item.h @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_ + +#include <string> +#include <vector> + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow_lite_support/cc/port/statusor.h" + +namespace tflite { +namespace task { +namespace vision { + +// Structure mapping a numerical class index output to a Knowledge Graph entity +// ID or any other string label representing this class. Optionally it is +// possible to specify an additional display name (in a given language) which is +// typically used for display purposes. +struct LabelMapItem { + // E.g. name = "/m/02xwb" + std::string name; + // E.g. display_name = "Fruit" + std::string display_name; + // Optional list of children (e.g. 
  // subcategories) used to represent a hierarchy.
  std::vector<std::string> child_name;
};

// Builds a label map from labels and (optional) display names file contents,
// both expected to contain one label per line. Those are typically obtained
// from TFLite Model Metadata TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS
// associated files.
// Returns an error e.g. if there's a mismatch between the number of labels and
// display names.
tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
    absl::string_view labels_file, absl::string_view display_names_file);

// A class that represents a hierarchy of labels as specified in a label map.
//
// For example, it is useful to determine if one label is a descendant of
// another label or not. This can be used to implement labels pruning based on
// hierarchy, e.g. if both "fruit" and "banana" have been inferred by a given
// classifier model prune "fruit" from the final results as "banana" is a more
// fine-grained descendant.
class LabelHierarchy {
 public:
  LabelHierarchy() = default;

  // Initializes the hierarchy of labels from a given label map vector. Returns
  // an error status in case of failure, typically if the input label map does
  // not contain any hierarchical relations between labels.
  absl::Status InitializeFromLabelMap(
      std::vector<LabelMapItem> label_map_items);

  // Returns true if `descendant_name` is a descendant of `ancestor_name` in the
  // hierarchy of labels. Invalid names, i.e. names which do not exist in the
  // label map used at initialization time, are ignored (they simply have no
  // ancestors, so the result is false).
  bool HaveAncestorDescendantRelationship(
      const std::string& ancestor_name,
      const std::string& descendant_name) const;

 private:
  // Retrieve and return all direct parent names, if any, for the input label
  // name.
  absl::flat_hash_set<std::string> GetParents(const std::string& name) const;

  // Retrieve all transitive ancestor names, if any, for the input label name.
  void GetAncestors(const std::string& name,
                    absl::flat_hash_set<std::string>* ancestors) const;

  // Label name (key) to parent names (value) direct mapping, built by
  // InitializeFromLabelMap().
  absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>>
      parents_map_;
};

}  // namespace vision
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_
diff --git a/tensorflow_lite_support/cc/task/vision/image_classifier.cc b/tensorflow_lite_support/cc/task/vision/image_classifier.cc
new file mode 100644
index 00000000..378797b4
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/image_classifier.cc
@@ -0,0 +1,572 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"

#include "absl/algorithm/container.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/interpreter.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"

namespace tflite {
namespace task {
namespace vision {

namespace {

using ::absl::StatusCode;
using ::tflite::metadata::ModelMetadataExtractor;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::task::core::AssertAndReturnTypedTensor;
using ::tflite::task::core::TaskAPIFactory;
using ::tflite::task::core::TfLiteEngine;

// Default score value used as a fallback for classes that (1) have no score
// calibration data or (2) have a very low confident uncalibrated score, i.e.
// lower than the `min_uncalibrated_score` threshold.
//
// (1) This happens when the ScoreCalibration does not cover all the classes
// listed in the label map. This can be used to enforce the blacklisting of
// given classes so that they are never returned.
//
// (2) This is an optional threshold provided part of the calibration data. It
// is used to mitigate false alarms on some classes.
//
// In both cases, a class that gets assigned a score of -1 is never returned as
// it gets discarded by the `score_threshold` check (see post-processing logic).
constexpr float kDefaultCalibratedScore = -1.0f;

// Calibrated scores should be in the [0, 1] range, otherwise an error is
// returned at post-processing time.
constexpr float kMinCalibratedScore = 0.0f;
constexpr float kMaxCalibratedScore = 1.0f;

}  // namespace

/* static */
StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::CreateFromOptions(
    const ImageClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  RETURN_IF_ERROR(SanityCheckOptions(options));

  // Copy options to ensure the ExternalFile outlives the constructed object.
  auto options_copy = absl::make_unique<ImageClassifierOptions>(options);

  ASSIGN_OR_RETURN(auto image_classifier,
                   TaskAPIFactory::CreateFromExternalFileProto<ImageClassifier>(
                       &options_copy->model_file_with_metadata(),
                       std::move(resolver), options_copy->num_threads()));

  RETURN_IF_ERROR(image_classifier->Init(std::move(options_copy)));

  return image_classifier;
}

/* static */
absl::Status ImageClassifier::SanityCheckOptions(
    const ImageClassifierOptions& options) {
  if (!options.has_model_file_with_metadata()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Missing mandatory `model_file_with_metadata` field",
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  // Note: negative `max_results` means "unlimited"; only 0 is invalid.
  if (options.max_results() == 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Invalid `max_results` option: value must be != 0",
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (options.score_threshold() < 0 || options.score_threshold() >= 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "`score_threshold` out of range: %f. Valid range is [0,1[.",
            options.score_threshold()),
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (options.class_name_whitelist_size() > 0 &&
      options.class_name_blacklist_size() > 0) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "`class_name_whitelist` and `class_name_blacklist` are mutually "
        "exclusive options.",
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  if (options.num_threads() == 0 || options.num_threads() < -1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "`num_threads` must be greater than 0 or equal to -1.",
        TfLiteSupportStatus::kInvalidArgumentError);
  }
  return absl::OkStatus();
}

absl::Status ImageClassifier::Init(
    std::unique_ptr<ImageClassifierOptions> options) {
  // Set options.
  options_ = std::move(options);

  // Perform pre-initialization actions (by default, sets the process engine for
  // image pre-processing to kLibyuv as a sane default).
  RETURN_IF_ERROR(PreInit());

  // Sanity check and set inputs and outputs.
  RETURN_IF_ERROR(CheckAndSetInputs());
  RETURN_IF_ERROR(CheckAndSetOutputs());

  // Initialize class whitelisting/blacklisting, if any.
  RETURN_IF_ERROR(CheckAndSetClassNameSet());

  // Perform final initialization (by default, initialize score calibration
  // parameters, if any).
  RETURN_IF_ERROR(PostInit());

  return absl::OkStatus();
}

absl::Status ImageClassifier::PreInit() {
  SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
  return absl::OkStatus();
}

absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); }

absl::Status ImageClassifier::CheckAndSetOutputs() {
  num_outputs_ = TfLiteEngine::OutputCount(engine_->interpreter());

  // Perform sanity checks and extract metadata.
  const ModelMetadataExtractor* metadata_extractor =
      engine_->metadata_extractor();

  const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
      output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();

  // Loop over output tensors metadata, if any.
  // Note: models with no output tensor metadata at all are supported.
  if (output_tensor_metadata != nullptr) {
    int num_output_tensors = output_tensor_metadata->size();

    if (num_outputs_ != num_output_tensors) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Mismatch between number of output tensors (%d) and "
                          "output tensors "
                          "metadata (%d).",
                          num_outputs_, num_output_tensors),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }

    for (int i = 0; i < num_output_tensors; ++i) {
      const tflite::TensorMetadata* output_tensor =
          output_tensor_metadata->Get(i);

      ASSIGN_OR_RETURN(
          ClassificationHead head,
          BuildClassificationHead(*metadata_extractor, *output_tensor,
                                  options_->display_names_locale()));

      classification_heads_.emplace_back(std::move(head));
    }
  }

  // If classifier heads are not set, build default ones based on model
  // introspection. This happens if a model with partial or no metadata was
  // provided through the `model_file_with_metadata` options field.
  if (classification_heads_.empty()) {
    classification_heads_.reserve(num_outputs_);
    for (int output_index = 0; output_index < num_outputs_; ++output_index) {
      classification_heads_.emplace_back(ClassificationHead{});
    }
  }

  if (num_outputs_ != classification_heads_.size()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Got %d classifier head(s), expected %d according to "
                        "the label map.",
                        num_outputs_, classification_heads_.size()),
        TfLiteSupportStatus::kMetadataInconsistencyError);
  }

  int num_quantized_outputs = 0;
  for (int i = 0; i < num_outputs_; ++i) {
    const TfLiteTensor* output_tensor =
        TfLiteEngine::GetOutput(engine_->interpreter(), i);
    const int num_dimensions = output_tensor->dims->size;
    // Accepted shapes are [1 x N] (2D) or [1 x 1 x 1 x N] (4D).
    if (num_dimensions == 4) {
      if (output_tensor->dims->data[1] != 1 ||
          output_tensor->dims->data[2] != 1) {
        return CreateStatusWithPayload(
            StatusCode::kInvalidArgument,
            absl::StrFormat("Unexpected WxH sizes for output index %d: got "
                            "%dx%d, expected 1x1.",
                            i, output_tensor->dims->data[2],
                            output_tensor->dims->data[1]),
            TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
      }
    } else if (num_dimensions != 2) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Unexpected number of dimensions for output index %d: got %dD, "
              "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, "
              "H=1).",
              i, num_dimensions),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
    if (output_tensor->dims->data[0] != 1) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("The output array is expected to have a batch size "
                          "of 1. Got %d for output index %d.",
                          output_tensor->dims->data[0], i),
          TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
    }
    int num_classes = output_tensor->dims->data[num_dimensions - 1];
    // If label map is not set, build a default one based on model
    // introspection. This happens if a model with partial or no metadata was
    // provided through the `model_file_with_metadata` options field.
    if (classification_heads_[i].label_map_items.empty()) {
      classification_heads_[i].label_map_items.reserve(num_classes);
      for (int class_index = 0; class_index < num_classes; ++class_index) {
        classification_heads_[i].label_map_items.emplace_back(LabelMapItem{});
      }
    }
    int num_label_map_items = classification_heads_[i].label_map_items.size();
    if (num_classes != num_label_map_items) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Got %d class(es) for output index %d, expected %d "
                          "according to the label map.",
                          output_tensor->dims->data[num_dimensions - 1], i,
                          num_label_map_items),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }
    if (output_tensor->type == kTfLiteUInt8) {
      num_quantized_outputs++;
    } else if (output_tensor->type != kTfLiteFloat32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Type mismatch for output tensor %s. Requested one "
                          "of these types: "
                          "kTfLiteUint8/kTfLiteFloat32, got %s.",
                          output_tensor->name,
                          TfLiteTypeGetName(output_tensor->type)),
          TfLiteSupportStatus::kInvalidOutputTensorTypeError);
    }
  }

  // Mixing quantized and float outputs across heads is not supported.
  if (num_quantized_outputs > 0 && num_quantized_outputs != num_outputs_) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Got %d quantized output(s), expected %d (i.e. all "
                        "provided outputs must be quantized).",
                        num_quantized_outputs, num_outputs_),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }
  has_uint8_outputs_ = (num_quantized_outputs > 0);

  return absl::OkStatus();
}

absl::Status ImageClassifier::CheckAndSetClassNameSet() {
  // Exit early if no blacklist/whitelist.
  if (options_->class_name_blacklist_size() == 0 &&
      options_->class_name_whitelist_size() == 0) {
    return absl::OkStatus();
  }

  // Before processing class names whitelist or blacklist from the input options
  // create a set with _all_ known class names from the label map(s).
  absl::flat_hash_set<std::string> all_class_names;
  int head_index = 0;
  for (const auto& head : classification_heads_) {
    absl::flat_hash_set<std::string> head_class_names;
    for (const auto& item : head.label_map_items) {
      if (!item.name.empty()) {
        head_class_names.insert(item.name);
      }
    }
    if (head_class_names.empty()) {
      std::string name = head.name;
      if (name.empty()) {
        name = absl::StrFormat("#%d", head_index);
      }
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat(
              "Using `class_name_whitelist` or `class_name_blacklist` "
              "requires labels to be present but none was found for "
              "classification head: %s",
              name),
          TfLiteSupportStatus::kMetadataMissingLabelsError);
    }
    all_class_names.insert(head_class_names.begin(), head_class_names.end());
    head_index++;
  }

  class_name_set_.is_whitelist = options_->class_name_whitelist_size() > 0;
  const auto& class_names = class_name_set_.is_whitelist
                                ? options_->class_name_whitelist()
                                : options_->class_name_blacklist();

  // Note: duplicate or unknown classes are just ignored.
  class_name_set_.values.clear();
  for (const auto& class_name : class_names) {
    if (!all_class_names.contains(class_name)) {
      continue;
    }
    class_name_set_.values.insert(class_name);
  }

  if (class_name_set_.values.empty()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Invalid class names specified via `class_name_%s`: none match "
            "with model labels.",
            class_name_set_.is_whitelist ? "whitelist" : "blacklist"),
        TfLiteSupportStatus::kInvalidArgumentError);
  }

  return absl::OkStatus();
}

absl::Status ImageClassifier::InitScoreCalibrations() {
  score_calibrations_.clear();
  score_calibrations_.resize(classification_heads_.size());

  for (int i = 0; i < classification_heads_.size(); ++i) {
    // Heads without calibration metadata keep a null entry, checked at
    // post-processing time.
    if (!classification_heads_[i].calibration_params.has_value()) {
      continue;
    }

    // Use a specific default score instead of the one specified by default in
    // cc/task/vision/utils/score_calibration.h. See `kDefaultCalibratedScore`
    // documentation for more details.
    classification_heads_[i].calibration_params->default_score =
        kDefaultCalibratedScore;

    score_calibrations_[i] = absl::make_unique<ScoreCalibration>();
    if (score_calibrations_[i] == nullptr) {
      return CreateStatusWithPayload(
          StatusCode::kInternal, "Could not create score calibration object.");
    }

    RETURN_IF_ERROR(score_calibrations_[i]->InitializeFromParameters(
        classification_heads_[i].calibration_params.value()));
  }

  return absl::OkStatus();
}

StatusOr<ClassificationResult> ImageClassifier::Classify(
    const FrameBuffer& frame_buffer) {
  // Default region of interest: the full frame.
  BoundingBox roi;
  roi.set_width(frame_buffer.dimension().width);
  roi.set_height(frame_buffer.dimension().height);
  return Classify(frame_buffer, roi);
}

StatusOr<ClassificationResult> ImageClassifier::Classify(
    const FrameBuffer& frame_buffer, const BoundingBox& roi) {
  return InferWithFallback(frame_buffer, roi);
}

StatusOr<ClassificationResult> ImageClassifier::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
  if (output_tensors.size() != num_outputs_) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        absl::StrFormat("Expected %d output tensors, found %d", num_outputs_,
                        output_tensors.size()));
  }

  ClassificationResult result;
  std::vector<std::pair<int, float>> score_pairs;

  for (int i = 0; i < num_outputs_; ++i) {
    auto* classifications = result.add_classifications();
    classifications->set_head_index(i);

    const auto& head = classification_heads_[i];
    score_pairs.clear();
    score_pairs.reserve(head.label_map_items.size());

    const TfLiteTensor* output_tensor = output_tensors[i];
    if (has_uint8_outputs_) {
      // Dequantize scores using the tensor's quantization parameters.
      const uint8* output_data =
          AssertAndReturnTypedTensor<uint8>(output_tensor);
      for (int j = 0; j < head.label_map_items.size(); ++j) {
        score_pairs.emplace_back(j, output_tensor->params.scale *
                                        (static_cast<int>(output_data[j]) -
                                         output_tensor->params.zero_point));
      }
    } else {
      const float* output_data =
          AssertAndReturnTypedTensor<float>(output_tensor);
      for (int j = 0; j < head.label_map_items.size(); ++j) {
        score_pairs.emplace_back(j, output_data[j]);
      }
    }

    // Optional score calibration. Note: `kDefaultCalibratedScore` (-1) is a
    // valid sentinel and deliberately bypasses the minimum-score check below.
    if (score_calibrations_[i] != nullptr) {
      for (auto& score_pair : score_pairs) {
        const std::string& class_name =
            head.label_map_items[score_pair.first].name;
        score_pair.second = score_calibrations_[i]->ComputeCalibratedScore(
            class_name, score_pair.second);
        if (score_pair.second > kMaxCalibratedScore) {
          return CreateStatusWithPayload(
              StatusCode::kInternal,
              absl::StrFormat("calibrated score is too high: got %f, expected "
                              "%f as maximum.",
                              score_pair.second, kMaxCalibratedScore));
        }
        if (score_pair.second != kDefaultCalibratedScore &&
            score_pair.second < kMinCalibratedScore) {
          return CreateStatusWithPayload(
              StatusCode::kInternal,
              absl::StrFormat("calibrated score is too low: got %f, expected "
                              "%f as minimum.",
                              score_pair.second, kMinCalibratedScore));
        }
      }
    }

    int num_results =
        options_->max_results() >= 0
            ? std::min(static_cast<int>(head.label_map_items.size()),
                       options_->max_results())
            : head.label_map_items.size();
    float score_threshold = options_->has_score_threshold()
                                ? options_->score_threshold()
                                : head.score_threshold;

    if (class_name_set_.values.empty()) {
      // Partially sort in descending order (higher score is better).
      absl::c_partial_sort(
          score_pairs, score_pairs.begin() + num_results,
          [](const std::pair<int, float>& a, const std::pair<int, float>& b) {
            return a.second > b.second;
          });

      for (int j = 0; j < num_results; ++j) {
        float score = score_pairs[j].second;
        if (score < score_threshold) {
          break;
        }
        auto* cl = classifications->add_classes();
        cl->set_index(score_pairs[j].first);
        cl->set_score(score);
      }
    } else {
      // A whitelist/blacklist is active: a full sort is needed since filtered
      // entries may push results past the `num_results` prefix.
      // Sort in descending order (higher score is better).
      absl::c_sort(score_pairs, [](const std::pair<int, float>& a,
                                   const std::pair<int, float>& b) {
        return a.second > b.second;
      });

      for (int j = 0; j < head.label_map_items.size(); ++j) {
        float score = score_pairs[j].second;
        if (score < score_threshold ||
            classifications->classes_size() >= num_results) {
          break;
        }

        const int class_index = score_pairs[j].first;
        const std::string& class_name = head.label_map_items[class_index].name;

        bool class_name_found = class_name_set_.values.contains(class_name);

        // Skip classes excluded by the whitelist (absent) or blacklist
        // (present).
        if ((!class_name_found && class_name_set_.is_whitelist) ||
            (class_name_found && !class_name_set_.is_whitelist)) {
          continue;
        }

        auto* cl = classifications->add_classes();
        cl->set_index(class_index);
        cl->set_score(score);
      }
    }
  }

  RETURN_IF_ERROR(FillResultsFromLabelMaps(&result));

  return result;
}

absl::Status ImageClassifier::FillResultsFromLabelMaps(
    ClassificationResult* result) {
  for (int i = 0; i < result->classifications_size(); ++i) {
    Classifications* classifications = result->mutable_classifications(i);
    int head_index = classifications->head_index();
    if (head_index < 0 || head_index >= classification_heads_.size()) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrFormat("Invalid head index (%d) with respect to total "
                          "number of classification heads (%d).",
                          head_index, classification_heads_.size()),
          TfLiteSupportStatus::kMetadataInconsistencyError);
    }
    const std::vector<LabelMapItem>& label_map_items =
        classification_heads_[head_index].label_map_items;
    for (int j = 0; j < classifications->classes_size(); ++j) {
      Class* current_class = classifications->mutable_classes(j);
      int current_class_index = current_class->index();
      if (current_class_index < 0 ||
          current_class_index >= label_map_items.size()) {
        return CreateStatusWithPayload(
            StatusCode::kInvalidArgument,
            absl::StrFormat("Invalid class index (%d) with respect to label "
                            "map size (%d) for head #%d.",
                            current_class_index, label_map_items.size(),
                            head_index),
            TfLiteSupportStatus::kMetadataInconsistencyError);
      }
      const std::string& name = label_map_items[current_class_index].name;
      if (!name.empty()) {
        current_class->set_class_name(name);
      }
      const std::string& display_name =
          label_map_items[current_class_index].display_name;
      if (!display_name.empty()) {
        current_class->set_display_name(display_name);
      }
    }
  }
  return absl::OkStatus();
}

}  // namespace vision
}  // namespace task
}  // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/vision/image_classifier.h b/tensorflow_lite_support/cc/task/vision/image_classifier.h
new file mode 100644
index 00000000..edd90931
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/image_classifier.h
@@ -0,0 +1,182 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_

#include <memory>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow_lite_support/cc/port/integral_types.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
#include "tensorflow_lite_support/cc/task/vision/core/classification_head.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"

namespace tflite {
namespace task {
namespace vision {

// Performs classification on images.
//
// The API expects a TFLite model with optional, but strongly recommended,
// TFLite Model Metadata.
//
// Input tensor:
//   (kTfLiteUInt8/kTfLiteFloat32)
//   - image input of size `[batch x height x width x channels]`.
//   - batch inference is not supported (`batch` is required to be 1).
//   - only RGB inputs are supported (`channels` is required to be 3).
//   - if type is kTfLiteFloat32, NormalizationOptions are required to be
//     attached to the metadata for input normalization.
// At least one output tensor with:
//   (kTfLiteUInt8/kTfLiteFloat32)
//   - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
//     `[1 x 1 x 1 x N]`
//   - optional (but recommended) label map(s) as AssociatedFile-s with type
//     TENSOR_AXIS_LABELS, containing one label per line. The first such
//     AssociatedFile (if any) is used to fill the `class_name` field of the
//     results. The `display_name` field is filled from the AssociatedFile (if
//     any) whose locale matches the `display_names_locale` field of the
//     `ImageClassifierOptions` used at creation time ("en" by default, i.e.
//     English). If none of these are available, only the `index` field of the
//     results will be filled.
//
// An example of such model can be found at:
// https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1
//
// A CLI demo tool is available for easily trying out this API, and provides
// example usage. See:
// examples/task/vision/desktop/image_classifier_demo.cc
class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
 public:
  using BaseVisionTaskApi::BaseVisionTaskApi;

  // Creates an ImageClassifier from the provided options. A non-default
  // OpResolver can be specified in order to support custom Ops or specify a
  // subset of built-in Ops.
  static tflite::support::StatusOr<std::unique_ptr<ImageClassifier>>
  CreateFromOptions(
      const ImageClassifierOptions& options,
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

  // Performs actual classification on the provided FrameBuffer.
  //
  // The FrameBuffer can be of any size and any of the supported formats, i.e.
  // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before
  // inference in order to (and in this order):
  // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to
  //   the dimensions of the model input tensor,
  // - convert it to the colorspace of the input tensor (i.e. RGB, which is the
  //   only supported colorspace for now),
  // - rotate it according to its `Orientation` so that inference is performed
  //   on an "upright" image.
  tflite::support::StatusOr<ClassificationResult> Classify(
      const FrameBuffer& frame_buffer);

  // Same as above, except that the classification is performed based on the
  // input region of interest. Cropping according to this region of interest is
  // prepended to the pre-processing operations.
  //
  // IMPORTANT: as a consequence of cropping occurring first, the provided
  // region of interest is expressed in the unrotated frame of reference
  // coordinates system, i.e. in `[0, frame_buffer.width) x [0,
  // frame_buffer.height)`, which are the dimensions of the underlying
  // `frame_buffer` data before any `Orientation` flag gets applied. Also, the
  // region of interest is not clamped, so this method will return a non-ok
  // status if the region is out of these bounds.
  tflite::support::StatusOr<ClassificationResult> Classify(
      const FrameBuffer& frame_buffer, const BoundingBox& roi);

 protected:
  // The options used to build this ImageClassifier.
  std::unique_ptr<ImageClassifierOptions> options_;

  // The list of classification heads associated with the corresponding output
  // tensors. Built from TFLite Model Metadata.
  std::vector<ClassificationHead> classification_heads_;

  // Post-processing to transform the raw model outputs into classification
  // results.
  tflite::support::StatusOr<ClassificationResult> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      const FrameBuffer& frame_buffer, const BoundingBox& roi) override;

  // Performs sanity checks on the provided ImageClassifierOptions.
  static absl::Status SanityCheckOptions(const ImageClassifierOptions& options);

  // Initializes the ImageClassifier from the provided ImageClassifierOptions,
  // whose ownership is transferred to this object.
  absl::Status Init(std::unique_ptr<ImageClassifierOptions> options);

  // Performs pre-initialization actions.
  virtual absl::Status PreInit();
  // Performs post-initialization actions.
  virtual absl::Status PostInit();

 private:
  // Performs sanity checks on the model outputs and extracts their metadata.
  absl::Status CheckAndSetOutputs();

  // Performs sanity checks on the class whitelist/blacklist and forms the class
  // name set.
  absl::Status CheckAndSetClassNameSet();

  // Initializes the score calibration parameters based on corresponding TFLite
  // Model Metadata, if any.
  absl::Status InitScoreCalibrations();

  // Given a ClassificationResult object containing class indices, fills the
  // name and display name from the label map(s).
  absl::Status FillResultsFromLabelMaps(ClassificationResult* result);

  // The number of output tensors. This corresponds to the number of
  // classification heads.
  int num_outputs_;
  // Whether the model features quantized inference type (QUANTIZED_UINT8). This
  // is currently detected by checking if all output tensors data type is uint8.
  bool has_uint8_outputs_;

  // Set of whitelisted or blacklisted class names.
  struct ClassNameSet {
    absl::flat_hash_set<std::string> values;
    bool is_whitelist;
  };

  // Whitelisted or blacklisted class names based on provided options at
  // construction time. These are used to filter out results during
  // post-processing.
  ClassNameSet class_name_set_;

  // List of score calibration parameters, if any. Built from TFLite Model
  // Metadata. Entries are null for heads without calibration data.
  std::vector<std::unique_ptr<ScoreCalibration>> score_calibrations_;
};

}  // namespace vision
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
diff --git a/tensorflow_lite_support/cc/task/vision/image_segmenter.cc b/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
new file mode 100644
index 00000000..4523b662
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/image_segmenter.cc
@@ -0,0 +1,427 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h" + +#include <algorithm> + +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace vision { + +namespace { + +using ::absl::StatusCode; +using ::tflite::TensorMetadata; +using ::tflite::metadata::ModelMetadataExtractor; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::core::AssertAndReturnTypedTensor; +using ::tflite::task::core::TaskAPIFactory; +using ::tflite::task::core::TfLiteEngine; + +// The maximum number of labels allowed in the labelmap. This is because so far +// segmentation masks are stored with 8 bit per pixel (flattened byte array). +constexpr uint32 kMaxNumClasses = 256; + +// TODO(b/) +// The colormap used to fill `ColoredLabel`-s, as a flattened array of 256 {R, +// G, B} components. 
+constexpr uint8 kColorMap[768] = { + 0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, + 128, 0, 128, 0, 128, 128, 128, 128, 128, 64, 0, 0, 192, 0, 0, + 64, 128, 0, 192, 128, 0, 64, 0, 128, 192, 0, 128, 64, 128, 128, + 192, 128, 128, 0, 64, 0, 128, 64, 0, 0, 192, 0, 128, 192, 0, + 0, 64, 128, 128, 64, 128, 0, 192, 128, 128, 192, 128, 64, 64, 0, + 192, 64, 0, 64, 192, 0, 192, 192, 0, 64, 64, 128, 192, 64, 128, + 64, 192, 128, 192, 192, 128, 0, 0, 64, 128, 0, 64, 0, 128, 64, + 128, 128, 64, 0, 0, 192, 128, 0, 192, 0, 128, 192, 128, 128, 192, + 64, 0, 64, 192, 0, 64, 64, 128, 64, 192, 128, 64, 64, 0, 192, + 192, 0, 192, 64, 128, 192, 192, 128, 192, 0, 64, 64, 128, 64, 64, + 0, 192, 64, 128, 192, 64, 0, 64, 192, 128, 64, 192, 0, 192, 192, + 128, 192, 192, 64, 64, 64, 192, 64, 64, 64, 192, 64, 192, 192, 64, + 64, 64, 192, 192, 64, 192, 64, 192, 192, 192, 192, 192, 32, 0, 0, + 160, 0, 0, 32, 128, 0, 160, 128, 0, 32, 0, 128, 160, 0, 128, + 32, 128, 128, 160, 128, 128, 96, 0, 0, 224, 0, 0, 96, 128, 0, + 224, 128, 0, 96, 0, 128, 224, 0, 128, 96, 128, 128, 224, 128, 128, + 32, 64, 0, 160, 64, 0, 32, 192, 0, 160, 192, 0, 32, 64, 128, + 160, 64, 128, 32, 192, 128, 160, 192, 128, 96, 64, 0, 224, 64, 0, + 96, 192, 0, 224, 192, 0, 96, 64, 128, 224, 64, 128, 96, 192, 128, + 224, 192, 128, 32, 0, 64, 160, 0, 64, 32, 128, 64, 160, 128, 64, + 32, 0, 192, 160, 0, 192, 32, 128, 192, 160, 128, 192, 96, 0, 64, + 224, 0, 64, 96, 128, 64, 224, 128, 64, 96, 0, 192, 224, 0, 192, + 96, 128, 192, 224, 128, 192, 32, 64, 64, 160, 64, 64, 32, 192, 64, + 160, 192, 64, 32, 64, 192, 160, 64, 192, 32, 192, 192, 160, 192, 192, + 96, 64, 64, 224, 64, 64, 96, 192, 64, 224, 192, 64, 96, 64, 192, + 224, 64, 192, 96, 192, 192, 224, 192, 192, 0, 32, 0, 128, 32, 0, + 0, 160, 0, 128, 160, 0, 0, 32, 128, 128, 32, 128, 0, 160, 128, + 128, 160, 128, 64, 32, 0, 192, 32, 0, 64, 160, 0, 192, 160, 0, + 64, 32, 128, 192, 32, 128, 64, 160, 128, 192, 160, 128, 0, 96, 0, + 128, 96, 0, 0, 224, 0, 128, 224, 0, 0, 
96, 128, 128, 96, 128, + 0, 224, 128, 128, 224, 128, 64, 96, 0, 192, 96, 0, 64, 224, 0, + 192, 224, 0, 64, 96, 128, 192, 96, 128, 64, 224, 128, 192, 224, 128, + 0, 32, 64, 128, 32, 64, 0, 160, 64, 128, 160, 64, 0, 32, 192, + 128, 32, 192, 0, 160, 192, 128, 160, 192, 64, 32, 64, 192, 32, 64, + 64, 160, 64, 192, 160, 64, 64, 32, 192, 192, 32, 192, 64, 160, 192, + 192, 160, 192, 0, 96, 64, 128, 96, 64, 0, 224, 64, 128, 224, 64, + 0, 96, 192, 128, 96, 192, 0, 224, 192, 128, 224, 192, 64, 96, 64, + 192, 96, 64, 64, 224, 64, 192, 224, 64, 64, 96, 192, 192, 96, 192, + 64, 224, 192, 192, 224, 192, 32, 32, 0, 160, 32, 0, 32, 160, 0, + 160, 160, 0, 32, 32, 128, 160, 32, 128, 32, 160, 128, 160, 160, 128, + 96, 32, 0, 224, 32, 0, 96, 160, 0, 224, 160, 0, 96, 32, 128, + 224, 32, 128, 96, 160, 128, 224, 160, 128, 32, 96, 0, 160, 96, 0, + 32, 224, 0, 160, 224, 0, 32, 96, 128, 160, 96, 128, 32, 224, 128, + 160, 224, 128, 96, 96, 0, 224, 96, 0, 96, 224, 0, 224, 224, 0, + 96, 96, 128, 224, 96, 128, 96, 224, 128, 224, 224, 128, 32, 32, 64, + 160, 32, 64, 32, 160, 64, 160, 160, 64, 32, 32, 192, 160, 32, 192, + 32, 160, 192, 160, 160, 192, 96, 32, 64, 224, 32, 64, 96, 160, 64, + 224, 160, 64, 96, 32, 192, 224, 32, 192, 96, 160, 192, 224, 160, 192, + 32, 96, 64, 160, 96, 64, 32, 224, 64, 160, 224, 64, 32, 96, 192, + 160, 96, 192, 32, 224, 192, 160, 224, 192, 96, 96, 64, 224, 96, 64, + 96, 224, 64, 224, 224, 64, 96, 96, 192, 224, 96, 192, 96, 224, 192, + 224, 224, 192}; + +StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny( + const ModelMetadataExtractor& metadata_extractor, + const TensorMetadata& tensor_metadata, absl::string_view locale) { + const std::string labels_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS); + if (labels_filename.empty()) { + return std::vector<LabelMapItem>(); + } + ASSIGN_OR_RETURN(absl::string_view labels_file, + metadata_extractor.GetAssociatedFile(labels_filename)); + 
const std::string display_names_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + tensor_metadata, tflite::AssociatedFileType_TENSOR_AXIS_LABELS, + locale); + absl::string_view display_names_file = nullptr; + if (!display_names_filename.empty()) { + ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile( + display_names_filename)); + } + return BuildLabelMapFromFiles(labels_file, display_names_file); +} + +} // namespace + +/* static */ +absl::Status ImageSegmenter::SanityCheckOptions( + const ImageSegmenterOptions& options) { + if (!options.has_model_file_with_metadata()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Missing mandatory `model_file_with_metadata` field", + TfLiteSupportStatus::kInvalidArgumentError); + } + if (options.output_type() == ImageSegmenterOptions::UNSPECIFIED) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "ImageSegmenterOptions: `output_type` must not be UNSPECIFIED", + TfLiteSupportStatus::kInvalidArgumentError); + } + if (options.num_threads() == 0 || options.num_threads() < -1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "`num_threads` must be greater than 0 or equal to -1.", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::CreateFromOptions( + const ImageSegmenterOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + // Copy options to ensure the ExternalFile outlives the constructed object. 
+ auto options_copy = absl::make_unique<ImageSegmenterOptions>(options); + + ASSIGN_OR_RETURN(auto image_segmenter, + TaskAPIFactory::CreateFromExternalFileProto<ImageSegmenter>( + &options_copy->model_file_with_metadata(), + std::move(resolver), options_copy->num_threads())); + + RETURN_IF_ERROR(image_segmenter->Init(std::move(options_copy))); + + return image_segmenter; +} + +absl::Status ImageSegmenter::Init( + std::unique_ptr<ImageSegmenterOptions> options) { + // Set options. + options_ = std::move(options); + + // Perform pre-initialization actions (by default, sets the process engine for + // image pre-processing to kLibyuv as a sane default). + RETURN_IF_ERROR(PreInit()); + + // Sanity check and set inputs and outputs. + RETURN_IF_ERROR(CheckAndSetInputs()); + RETURN_IF_ERROR(CheckAndSetOutputs()); + + // Initialize colored_labels_ once and for all. + RETURN_IF_ERROR(InitColoredLabels()); + + return absl::OkStatus(); +} + +absl::Status ImageSegmenter::PreInit() { + SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv); + return absl::OkStatus(); +} + +absl::Status ImageSegmenter::CheckAndSetOutputs() { + // First, sanity checks on the model itself. + const TfLiteEngine::Interpreter* interpreter = engine_->interpreter(); + + // Check the number of output tensors. + if (TfLiteEngine::OutputCount(interpreter) != 1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Image segmentation models are expected to have only 1 " + "output, found %d", + TfLiteEngine::OutputCount(interpreter)), + TfLiteSupportStatus::kInvalidNumOutputTensorsError); + } + const TfLiteTensor* output_tensor = TfLiteEngine::GetOutput(interpreter, 0); + + // Check tensor dimensions. 
+ if (output_tensor->dims->size != 4) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Output tensor is expected to have 4 dimensions, found %d.", + output_tensor->dims->size), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + if (output_tensor->dims->data[0] != 1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected batch size of 1, found %d.", + output_tensor->dims->data[0]), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + output_height_ = output_tensor->dims->data[1]; + output_width_ = output_tensor->dims->data[2]; + output_depth_ = output_tensor->dims->data[3]; + if (output_depth_ > kMaxNumClasses) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected at most %d output classes, found %d", + kMaxNumClasses, output_depth_), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + + // Check tensor type. + if (output_tensor->type != kTfLiteFloat32 && + output_tensor->type != kTfLiteUInt8) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Type mismatch for output tensor. Requested one of " + "these types: kTfLiteUint8/kTfLiteFloat32, got %s.", + TfLiteTypeGetName(output_tensor->type)), + TfLiteSupportStatus::kInvalidOutputTensorTypeError); + } + has_uint8_outputs_ = (output_tensor->type == kTfLiteUInt8); + + // Build label map from metadata, if available. + const ModelMetadataExtractor* metadata_extractor = + engine_->metadata_extractor(); + const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>* + output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata(); + if (output_tensor_metadata != nullptr) { + // Check metadata consistency. 
+ if (output_tensor_metadata->size() != 1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Mismatch between number of output tensors (1) and " + "output tensors metadata (%d).", + output_tensor_metadata->size()), + TfLiteSupportStatus::kMetadataInconsistencyError); + } + ASSIGN_OR_RETURN( + label_map_, + GetLabelMapIfAny(*metadata_extractor, *output_tensor_metadata->Get(0), + options_->display_names_locale())); + } + + // If label map is still empty, build a default one. + if (label_map_.empty()) { + for (int class_index = 0; class_index < output_depth_; ++class_index) { + label_map_.emplace_back(LabelMapItem{}); + } + } + + return absl::OkStatus(); +} + +absl::Status ImageSegmenter::InitColoredLabels() { + for (int i = 0; i < label_map_.size(); ++i) { + Segmentation::ColoredLabel colored_label; + colored_label.set_r(kColorMap[3 * i]); + colored_label.set_g(kColorMap[3 * i + 1]); + colored_label.set_b(kColorMap[3 * i + 2]); + const LabelMapItem& item = label_map_[i]; + if (!item.name.empty()) { + colored_label.set_class_name(item.name); + } + if (!item.display_name.empty()) { + colored_label.set_display_name(item.display_name); + } + colored_labels_.push_back(colored_label); + } + return absl::OkStatus(); +} + +StatusOr<SegmentationResult> ImageSegmenter::Segment( + const FrameBuffer& frame_buffer) { + BoundingBox roi; + roi.set_width(frame_buffer.dimension().width); + roi.set_height(frame_buffer.dimension().height); + return InferWithFallback(frame_buffer, roi); +} + +StatusOr<SegmentationResult> ImageSegmenter::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) { + if (output_tensors.size() != 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Expected 1 output tensors, found %d", + output_tensors.size())); + } + const TfLiteTensor* output_tensor = output_tensors[0]; + + SegmentationResult result; + 
Segmentation* segmentation = result.add_segmentation(); + *segmentation->mutable_colored_labels() = {colored_labels_.begin(), + colored_labels_.end()}; + + // The output tensor has orientation `frame_buffer.orientation()`, as it has + // been produced from the pre-processed frame. + FrameBuffer::Orientation tensor_orientation = frame_buffer.orientation(); + // The output tensor always has size `output_width_ x output_height_` + FrameBuffer::Dimension tensor_dimension = {output_width_, output_height_}; + + // The masks to produce from the output tensor need to be re-oriented in the + // unrotated frame of reference coordinates system, i.e. kTopLeft. + FrameBuffer::Orientation mask_orientation = + FrameBuffer::Orientation::kTopLeft; + // They may thus have swapped dimensions compared to the tensor if the + // rotation is 90° or 270°. + FrameBuffer::Dimension mask_dimension(tensor_dimension); + if (RequireDimensionSwap(frame_buffer.orientation(), + FrameBuffer::Orientation::kTopLeft)) { + mask_dimension.Swap(); + } + segmentation->set_width(mask_dimension.width); + segmentation->set_height(mask_dimension.height); + + // XY coordinates in the tensor, to be computed from mask_x and mask_y below. + int tensor_x; + int tensor_y; + + if (options_->output_type() == ImageSegmenterOptions::CATEGORY_MASK) { + auto* category_mask = segmentation->mutable_category_mask(); + category_mask->resize(mask_dimension.width * mask_dimension.height); + int pixel_offset = 0; + for (int mask_y = 0; mask_y < mask_dimension.height; ++mask_y) { + for (int mask_x = 0; mask_x < mask_dimension.width; ++mask_x) { + // Compute the coordinates (tensor_x, tensor_y) in the tensor with + // tensor_orientation = frame_buffer.orientation() corresponding to the + // coordinates (mask_x, mask_y) in the mask being filled with + // mask_orientation = kTopLeft, i.e. the orientation of the unrotated + // frame of reference. 
+ OrientCoordinates(/*from_x=*/mask_x, + /*from_y=*/mask_y, + /*from_orientation=*/mask_orientation, + /*to_orientation=*/tensor_orientation, + /*from_dimension=*/mask_dimension, + /*to_x=*/&tensor_x, + /*to_y=*/&tensor_y); + int class_index = 0; + float max_confidence = 0.0f; + for (int d = 0; d < output_depth_; ++d) { + const float confidence = + GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d); + if (confidence > max_confidence) { + class_index = d; + max_confidence = confidence; + } + } + (*category_mask)[pixel_offset++] = static_cast<char>(class_index); + } + } + } else if (options_->output_type() == + ImageSegmenterOptions::CONFIDENCE_MASK) { + auto* confidence_masks = segmentation->mutable_confidence_masks(); + for (int d = 0; d < output_depth_; ++d) { + confidence_masks->add_confidence_mask(); + } + for (int mask_y = 0; mask_y < segmentation->height(); ++mask_y) { + for (int mask_x = 0; mask_x < segmentation->width(); ++mask_x) { + // See above. + OrientCoordinates(/*from_x=*/mask_x, + /*from_y=*/mask_y, + /*from_orientation=*/mask_orientation, + /*to_orientation=*/tensor_orientation, + /*from_dimension=*/mask_dimension, + /*to_x=*/&tensor_x, + /*to_y=*/&tensor_y); + for (int d = 0; d < output_depth_; ++d) { + confidence_masks->mutable_confidence_mask(d)->add_value( + GetOutputConfidence(*output_tensor, tensor_x, tensor_y, d)); + } + } + } + } + + return result; +} + +float ImageSegmenter::GetOutputConfidence(const TfLiteTensor& output_tensor, + int x, int y, int depth) { + int index = output_width_ * output_depth_ * y + output_depth_ * x + depth; + if (has_uint8_outputs_) { + const uint8* data = AssertAndReturnTypedTensor<uint8>(&output_tensor); + return output_tensor.params.scale * + (static_cast<int>(data[index]) - output_tensor.params.zero_point); + } else { + const float* data = AssertAndReturnTypedTensor<float>(&output_tensor); + return data[index]; + } +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git 
a/tensorflow_lite_support/cc/task/vision/image_segmenter.h b/tensorflow_lite_support/cc/task/vision/image_segmenter.h new file mode 100644 index 00000000..663ddb70 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/image_segmenter.h @@ -0,0 +1,172 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_ + +#include <memory> +#include <vector> + +#include "absl/status/status.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h" + +namespace tflite { +namespace task { +namespace vision { + +// Performs segmentation on images. 
+// +// The API expects a TFLite model with optional, but strongly recommended, +// TFLite Model Metadata. +// +// Input tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - image input of size `[batch x height x width x channels]`. +// - batch inference is not supported (`batch` is required to be 1). +// - only RGB inputs are supported (`channels` is required to be 3). +// - if type is kTfLiteFloat32, NormalizationOptions are required to be +// attached to the metadata for input normalization. +// Output tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - tensor of size `[batch x mask_height x mask_width x num_classes]`, where +// `batch` is required to be 1, `mask_width` and `mask_height` are the +// dimensions of the segmentation masks produced by the model, and +// `num_classes` is the number of classes supported by the model. +// - optional (but recommended) label map(s) can be attached as +// AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per +// line. The first such AssociatedFile (if any) is used to fill the +// `class_name` field of the results. The `display_name` field is filled +// from the AssociatedFile (if any) whose locale matches the +// `display_names_locale` field of the `ImageSegmenterOptions` used at +// creation time ("en" by default, i.e. English). If none of these are +// available, only the `index` field of the results will be filled. +// +// An example of such model can be found at: +// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1 +// +// A CLI demo tool is available for easily trying out this API, and provides +// example usage. See: +// examples/task/vision/desktop/image_segmenter_demo.cc +class ImageSegmenter : public BaseVisionTaskApi<SegmentationResult> { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates an ImageSegmenter from the provided options. A non-default + // OpResolver can be specified in order to support custom Ops or specify a + // subset of built-in Ops. 
+ static tflite::support::StatusOr<std::unique_ptr<ImageSegmenter>> + CreateFromOptions( + const ImageSegmenterOptions& options, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + + // Performs actual segmentation on the provided FrameBuffer. + // + // The FrameBuffer can be of any size and any of the supported formats, i.e. + // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before + // inference in order to (and in this order): + // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to + // the dimensions of the model input tensor, + // - convert it to the colorspace of the input tensor (i.e. RGB, which is the + // only supported colorspace for now), + // - rotate it according to its `Orientation` so that inference is performed + // on an "upright" image. + // + // IMPORTANT: the returned segmentation masks are not direcly suited for + // display, in particular: + // * they are relative to the unrotated input frame, i.e. *not* taking into + // account the `Orientation` flag of the input FrameBuffer, + // * their dimensions are intrinsic to the model, i.e. *not* dependent on the + // input FrameBuffer dimensions. + // + // Example of such post-processing, assuming: + // * an input FrameBuffer with width=640, height=480, orientation=kLeftBottom + // (i.e. the image will be rotated 90° clockwise during preprocessing to + // make it "upright"), + // * a model outputting masks of size 224x224. + // In order to be directly displayable on top of the input image assumed to + // be displayed *with* the `Orientation` flag taken into account according to + // the EXIF specification (http://jpegclub.org/exif_orientation.html), the + // masks need to be: + // * re-scaled to 640 x 480, + // * then rotated 90° clockwise. 
+ tflite::support::StatusOr<SegmentationResult> Segment( + const FrameBuffer& frame_buffer); + + protected: + // Post-processing to transform the raw model outputs into segmentation + // results. + tflite::support::StatusOr<SegmentationResult> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const FrameBuffer& frame_buffer, const BoundingBox& roi) override; + + // Performs sanity checks on the provided ImageSegmenterOptions. + static absl::Status SanityCheckOptions(const ImageSegmenterOptions& options); + + // Initializes the Segmenter from the provided ImageSegmenterOptions, whose + // ownership is transferred to this object. + absl::Status Init(std::unique_ptr<ImageSegmenterOptions> options); + + // Performs pre-initialization actions. + virtual absl::Status PreInit(); + + // The options used for building this image segmenter. + std::unique_ptr<ImageSegmenterOptions> options_; + + // The label map, extracted from the TFLite Model Metadata. + std::vector<LabelMapItem> label_map_; + + private: + // Performs sanity checks on the model outputs and extracts their metadata. + absl::Status CheckAndSetOutputs(); + + // Initializes the colored labels list from `label_map_` and stores it in + // `colored_labels_`. + absl::Status InitColoredLabels(); + + // Returns the output confidence at coordinates {x, y, depth}, dequantizing + // on-the-fly if needed (i.e. if `has_uint8_outputs_` is true). + float GetOutputConfidence(const TfLiteTensor& output_tensor, int x, int y, + int depth); + + // Prebuilt list of ColoredLabel attached to each Segmentation result. The + // i-th item in this list corresponds to the i-th label map item. + std::vector<Segmentation::ColoredLabel> colored_labels_; + + // Whether the model features quantized inference type (QUANTIZED_UINT8). This + // is currently detected by checking if all output tensors data type is uint8. + bool has_uint8_outputs_; + + // Expected output width. + int output_width_; + // Expected output height. 
+ int output_height_; + // Expected output depth. This corresponds to the number of supported classes. + int output_depth_; +}; + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_SEGMENTER_H_ diff --git a/tensorflow_lite_support/cc/task/vision/object_detector.cc b/tensorflow_lite_support/cc/task/vision/object_detector.cc new file mode 100644 index 00000000..22ec3019 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/object_detector.cc @@ -0,0 +1,549 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/object_detector.h" + +#include <algorithm> +#include <limits> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/task_api_factory.h" +#include "tensorflow_lite_support/cc/task/core/task_utils.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace vision { + +namespace { + +using ::absl::StatusCode; +using ::tflite::BoundingBoxProperties; +using ::tflite::ContentProperties; +using ::tflite::ContentProperties_BoundingBoxProperties; +using ::tflite::EnumNameContentProperties; +using ::tflite::ProcessUnit; +using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions; +using ::tflite::TensorMetadata; +using ::tflite::metadata::ModelMetadataExtractor; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::core::AssertAndReturnTypedTensor; +using ::tflite::task::core::TaskAPIFactory; +using ::tflite::task::core::TfLiteEngine; + +// The expected number of dimensions of the 4 output tensors, representing in +// that order: 
locations, classes, scores, num_results. +static constexpr int kOutputTensorsExpectedDims[4] = {3, 2, 2, 1}; + +StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties( + const TensorMetadata& tensor_metadata) { + if (tensor_metadata.content() == nullptr || + tensor_metadata.content()->content_properties() == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected BoundingBoxProperties for tensor %s, found none.", + tensor_metadata.name() ? tensor_metadata.name()->str() : "#0"), + TfLiteSupportStatus::kMetadataInvalidContentPropertiesError); + } + + ContentProperties type = tensor_metadata.content()->content_properties_type(); + if (type != ContentProperties_BoundingBoxProperties) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected BoundingBoxProperties for tensor %s, found %s.", + tensor_metadata.name() ? tensor_metadata.name()->str() : "#0", + EnumNameContentProperties(type)), + TfLiteSupportStatus::kMetadataInvalidContentPropertiesError); + } + + const BoundingBoxProperties* properties = + tensor_metadata.content()->content_properties_as_BoundingBoxProperties(); + + // Mobile SSD only supports "BOUNDARIES" bounding box type. + if (properties->type() != tflite::BoundingBoxType_BOUNDARIES) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Mobile SSD only supports BoundingBoxType BOUNDARIES, found %s", + tflite::EnumNameBoundingBoxType(properties->type())), + TfLiteSupportStatus::kMetadataInvalidContentPropertiesError); + } + + // Mobile SSD only supports "RATIO" coordinates type. 
+ if (properties->coordinate_type() != tflite::CoordinateType_RATIO) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Mobile SSD only supports CoordinateType RATIO, found %s", + tflite::EnumNameCoordinateType(properties->coordinate_type())), + TfLiteSupportStatus::kMetadataInvalidContentPropertiesError); + } + + // Index is optional, but must contain 4 values if present. + if (properties->index() != nullptr && properties->index()->size() != 4) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Expected BoundingBoxProperties index to contain 4 values, found " + "%d", + properties->index()->size()), + TfLiteSupportStatus::kMetadataInvalidContentPropertiesError); + } + + return properties; +} + +StatusOr<std::vector<LabelMapItem>> GetLabelMapIfAny( + const ModelMetadataExtractor& metadata_extractor, + const TensorMetadata& tensor_metadata, absl::string_view locale) { + const std::string labels_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS); + if (labels_filename.empty()) { + return std::vector<LabelMapItem>(); + } + ASSIGN_OR_RETURN(absl::string_view labels_file, + metadata_extractor.GetAssociatedFile(labels_filename)); + const std::string display_names_filename = + ModelMetadataExtractor::FindFirstAssociatedFileName( + tensor_metadata, tflite::AssociatedFileType_TENSOR_VALUE_LABELS, + locale); + absl::string_view display_names_file = nullptr; + if (!display_names_filename.empty()) { + ASSIGN_OR_RETURN(display_names_file, metadata_extractor.GetAssociatedFile( + display_names_filename)); + } + return BuildLabelMapFromFiles(labels_file, display_names_file); +} + +StatusOr<float> GetScoreThreshold( + const ModelMetadataExtractor& metadata_extractor, + const TensorMetadata& tensor_metadata) { + ASSIGN_OR_RETURN( + const ProcessUnit* score_thresholding_process_unit, + metadata_extractor.FindFirstProcessUnit( + 
tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions)); + if (score_thresholding_process_unit == nullptr) { + return std::numeric_limits<float>::lowest(); + } + return score_thresholding_process_unit->options_as_ScoreThresholdingOptions() + ->global_score_threshold(); +} + +absl::Status SanityCheckOutputTensors( + const std::vector<const TfLiteTensor*>& output_tensors) { + if (output_tensors.size() != 4) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Expected 4 output tensors, found %d", + output_tensors.size())); + } + + // Get number of results. + if (output_tensors[3]->dims->data[0] != 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat( + "Expected tensor with dimensions [1] at index 3, found [%d]", + output_tensors[3]->dims->data[0])); + } + int num_results = + static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]); + + // Check dimensions for the other tensors are correct. + if (output_tensors[0]->dims->data[0] != 1 || + output_tensors[0]->dims->data[1] != num_results || + output_tensors[0]->dims->data[2] != 4) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat( + "Expected locations tensor with dimensions [1,%d,4] at index 0, " + "found [%d,%d,%d].", + num_results, output_tensors[0]->dims->data[0], + output_tensors[0]->dims->data[1], + output_tensors[0]->dims->data[2])); + } + if (output_tensors[1]->dims->data[0] != 1 || + output_tensors[1]->dims->data[1] != num_results) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat( + "Expected classes tensor with dimensions [1,%d] at index 1, " + "found [%d,%d].", + num_results, output_tensors[1]->dims->data[0], + output_tensors[1]->dims->data[1])); + } + if (output_tensors[2]->dims->data[0] != 1 || + output_tensors[2]->dims->data[1] != num_results) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat( + "Expected scores tensor with dimensions 
[1,%d] at index 2, " + "found [%d,%d].", + num_results, output_tensors[2]->dims->data[0], + output_tensors[2]->dims->data[1])); + } + + return absl::OkStatus(); +} + +} // namespace + +/* static */ +absl::Status ObjectDetector::SanityCheckOptions( + const ObjectDetectorOptions& options) { + if (!options.has_model_file_with_metadata()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Missing mandatory `model_file_with_metadata` field", + TfLiteSupportStatus::kInvalidArgumentError); + } + if (options.max_results() == 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Invalid `max_results` option: value must be != 0", + TfLiteSupportStatus::kInvalidArgumentError); + } + if (options.class_name_whitelist_size() > 0 && + options.class_name_blacklist_size() > 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "`class_name_whitelist` and `class_name_blacklist` are mutually " + "exclusive options.", + TfLiteSupportStatus::kInvalidArgumentError); + } + if (options.num_threads() == 0 || options.num_threads() < -1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "`num_threads` must be greater than 0 or equal to -1.", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +/* static */ +StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::CreateFromOptions( + const ObjectDetectorOptions& options, + std::unique_ptr<tflite::OpResolver> resolver) { + RETURN_IF_ERROR(SanityCheckOptions(options)); + + // Copy options to ensure the ExternalFile outlives the constructed object. 
+ auto options_copy = absl::make_unique<ObjectDetectorOptions>(options); + + ASSIGN_OR_RETURN(auto object_detector, + TaskAPIFactory::CreateFromExternalFileProto<ObjectDetector>( + &options_copy->model_file_with_metadata(), + std::move(resolver), options_copy->num_threads())); + + RETURN_IF_ERROR(object_detector->Init(std::move(options_copy))); + + return object_detector; +} + +absl::Status ObjectDetector::Init( + std::unique_ptr<ObjectDetectorOptions> options) { + // Set options. + options_ = std::move(options); + + // Perform pre-initialization actions (by default, sets the process engine for + // image pre-processing to kLibyuv as a sane default). + RETURN_IF_ERROR(PreInit()); + + // Sanity check and set inputs and outputs. + RETURN_IF_ERROR(CheckAndSetInputs()); + RETURN_IF_ERROR(CheckAndSetOutputs()); + + // Initialize class whitelisting/blacklisting, if any. + RETURN_IF_ERROR(CheckAndSetClassIndexSet()); + + return absl::OkStatus(); +} + +absl::Status ObjectDetector::PreInit() { + SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv); + return absl::OkStatus(); +} + +absl::Status ObjectDetector::CheckAndSetOutputs() { + // First, sanity checks on the model itself. + const TfLiteEngine::Interpreter* interpreter = engine_->interpreter(); + // Check the number of output tensors. + if (TfLiteEngine::OutputCount(interpreter) != 4) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Mobile SSD models are expected to have exactly 4 " + "outputs, found %d", + TfLiteEngine::OutputCount(interpreter)), + TfLiteSupportStatus::kInvalidNumOutputTensorsError); + } + // Check tensor dimensions and batch size. 
+ for (int i = 0; i < 4; ++i) { + const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, i); + if (tensor->dims->size != kOutputTensorsExpectedDims[i]) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Output tensor at index %d is expected to " + "have %d dimensions, found %d.", + i, kOutputTensorsExpectedDims[i], tensor->dims->size), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + if (tensor->dims->data[0] != 1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected batch size of 1, found %d.", + tensor->dims->data[0]), + TfLiteSupportStatus::kInvalidOutputTensorDimensionsError); + } + } + + // Now, perform sanity checks and extract metadata. + const ModelMetadataExtractor* metadata_extractor = + engine_->metadata_extractor(); + // Check that metadata is available. + if (metadata_extractor->GetModelMetadata() == nullptr || + metadata_extractor->GetModelMetadata()->subgraph_metadata() == nullptr) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Object detection models require TFLite " + "Model Metadata but none was found", + TfLiteSupportStatus::kMetadataNotFoundError); + } + // Check output tensor metadata is present and consistent with model. + auto output_tensors_metadata = metadata_extractor->GetOutputTensorMetadata(); + if (output_tensors_metadata == nullptr || + output_tensors_metadata->size() != 4) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Mismatch between number of output tensors (4) and output tensors " + "metadata (%d).", + output_tensors_metadata == nullptr + ? 0 + : output_tensors_metadata->size()), + TfLiteSupportStatus::kMetadataInconsistencyError); + } + + // Extract mandatory BoundingBoxProperties for easier access at + // post-processing time, performing sanity checks on the fly. 
+ ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties, + GetBoundingBoxProperties(*output_tensors_metadata->Get(0))); + if (bounding_box_properties->index() == nullptr) { + bounding_box_corners_order_ = {0, 1, 2, 3}; + } else { + auto bounding_box_index = bounding_box_properties->index(); + bounding_box_corners_order_ = { + bounding_box_index->Get(0), + bounding_box_index->Get(1), + bounding_box_index->Get(2), + bounding_box_index->Get(3), + }; + } + + // Build label map (if available) from metadata. + ASSIGN_OR_RETURN( + label_map_, + GetLabelMapIfAny(*metadata_extractor, *output_tensors_metadata->Get(1), + options_->display_names_locale())); + + // Set score threshold. + if (options_->has_score_threshold()) { + score_threshold_ = options_->score_threshold(); + } else { + ASSIGN_OR_RETURN(score_threshold_, + GetScoreThreshold(*metadata_extractor, + *output_tensors_metadata->Get(2))); + } + + return absl::OkStatus(); +} + +absl::Status ObjectDetector::CheckAndSetClassIndexSet() { + // Exit early if no blacklist/whitelist. + if (options_->class_name_blacklist_size() == 0 && + options_->class_name_whitelist_size() == 0) { + return absl::OkStatus(); + } + // Label map is mandatory. + if (label_map_.empty()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Using `class_name_whitelist` or `class_name_blacklist` requires " + "labels to be present in the TFLite Model Metadata but none was found.", + TfLiteSupportStatus::kMetadataMissingLabelsError); + } + + class_index_set_.is_whitelist = options_->class_name_whitelist_size() > 0; + const auto& class_names = class_index_set_.is_whitelist + ? options_->class_name_whitelist() + : options_->class_name_blacklist(); + class_index_set_.values.clear(); + for (const auto& class_name : class_names) { + int index = -1; + for (int i = 0; i < label_map_.size(); ++i) { + if (label_map_[i].name == class_name) { + index = i; + break; + } + } + // Ignore duplicate or unknown classes. 
+ if (index < 0 || class_index_set_.values.contains(index)) { + continue; + } + class_index_set_.values.insert(index); + } + + if (class_index_set_.values.empty()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Invalid class names specified via `class_name_%s`: none match " + "with model labels.", + class_index_set_.is_whitelist ? "whitelist" : "blacklist"), + TfLiteSupportStatus::kInvalidArgumentError); + } + + return absl::OkStatus(); +} + +StatusOr<DetectionResult> ObjectDetector::Detect( + const FrameBuffer& frame_buffer) { + BoundingBox roi; + roi.set_width(frame_buffer.dimension().width); + roi.set_height(frame_buffer.dimension().height); + // Rely on `Infer` instead of `InferWithFallback` as DetectionPostprocessing + // op doesn't support hardware acceleration at the time. + return Infer(frame_buffer, roi); +} + +StatusOr<DetectionResult> ObjectDetector::Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const FrameBuffer& frame_buffer, const BoundingBox& /*roi*/) { + // Most of the checks here should never happen, as outputs have been validated + // at construction time. Checking nonetheless and returning internal errors if + // something bad happens. + RETURN_IF_ERROR(SanityCheckOutputTensors(output_tensors)); + + // Get number of available results. + const int num_results = + static_cast<int>(AssertAndReturnTypedTensor<float>(output_tensors[3])[0]); + // Compute number of max results to return. + const int max_results = options_->max_results() > 0 + ? std::min(options_->max_results(), num_results) + : num_results; + // The dimensions of the upright (i.e. rotated according to its orientation) + // input frame. 
+ FrameBuffer::Dimension upright_input_frame_dimensions = + frame_buffer.dimension(); + if (RequireDimensionSwap(frame_buffer.orientation(), + FrameBuffer::Orientation::kTopLeft)) { + upright_input_frame_dimensions.Swap(); + } + + const float* locations = AssertAndReturnTypedTensor<float>(output_tensors[0]); + const float* classes = AssertAndReturnTypedTensor<float>(output_tensors[1]); + const float* scores = AssertAndReturnTypedTensor<float>(output_tensors[2]); + DetectionResult results; + for (int i = 0; i < num_results; ++i) { + const int class_index = static_cast<int>(classes[i]); + const float score = scores[i]; + if (!IsClassIndexAllowed(class_index) || score < score_threshold_) { + continue; + } + Detection* detection = results.add_detections(); + // Denormalize the bounding box cooordinates in the upright frame + // coordinates system, then rotate back from frame_buffer.orientation() to + // the unrotated frame of reference coordinates system (i.e. with + // orientation = kTopLeft). 
+ *detection->mutable_bounding_box() = OrientAndDenormalizeBoundingBox( + /*from_left=*/locations[4 * i + bounding_box_corners_order_[0]], + /*from_top=*/locations[4 * i + bounding_box_corners_order_[1]], + /*from_right=*/locations[4 * i + bounding_box_corners_order_[2]], + /*from_bottom=*/locations[4 * i + bounding_box_corners_order_[3]], + /*from_orientation=*/frame_buffer.orientation(), + /*to_orientation=*/FrameBuffer::Orientation::kTopLeft, + /*from_dimension=*/upright_input_frame_dimensions); + Class* detection_class = detection->add_classes(); + detection_class->set_index(class_index); + detection_class->set_score(score); + if (results.detections_size() == max_results) { + break; + } + } + + if (!label_map_.empty()) { + RETURN_IF_ERROR(FillResultsFromLabelMap(&results)); + } + + return results; +} + +bool ObjectDetector::IsClassIndexAllowed(int class_index) { + if (class_index_set_.values.empty()) { + return true; + } + if (class_index_set_.is_whitelist) { + return class_index_set_.values.contains(class_index); + } else { + return !class_index_set_.values.contains(class_index); + } +} + +absl::Status ObjectDetector::FillResultsFromLabelMap(DetectionResult* result) { + for (int i = 0; i < result->detections_size(); ++i) { + Detection* detection = result->mutable_detections(i); + for (int j = 0; j < detection->classes_size(); ++j) { + Class* detection_class = detection->mutable_classes(j); + const int index = detection_class->index(); + if (index >= label_map_.size()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Label map does not contain enough elements: model returned " + "class index %d but label map only contains %d elements.", + index, label_map_.size()), + TfLiteSupportStatus::kMetadataInconsistencyError); + } + std::string name = label_map_[index].name; + if (!name.empty()) { + detection_class->set_class_name(name); + } + std::string display_name = label_map_[index].display_name; + if (!display_name.empty()) 
{ + detection_class->set_display_name(display_name); + } + } + } + return absl::OkStatus(); +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/vision/object_detector.h b/tensorflow_lite_support/cc/task/vision/object_detector.h new file mode 100644 index 00000000..2bd220b3 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/object_detector.h @@ -0,0 +1,186 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_OBJECT_DETECTOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_OBJECT_DETECTOR_H_ + +#include <memory> + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" +#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h" + +namespace tflite { +namespace task { +namespace vision { + +// Performs object detection on images. +// +// The API expects a TFLite model with mandatory TFLite Model Metadata. +// +// Input tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - image input of size `[batch x height x width x channels]`. +// - batch inference is not supported (`batch` is required to be 1). +// - only RGB inputs are supported (`channels` is required to be 3). +// - if type is kTfLiteFloat32, NormalizationOptions are required to be +// attached to the metadata for input normalization. +// Output tensors must be the 4 outputs of a `DetectionPostProcess` op, i.e: +// (kTfLiteFloat32) +// - locations tensor of size `[num_results x 4]`, the inner array +// representing bounding boxes in the form [top, left, right, bottom]. +// - BoundingBoxProperties are required to be attached to the metadata +// and must specify type=BOUNDARIES and coordinate_type=RATIO. +// (kTfLiteFloat32) +// - classes tensor of size `[num_results]`, each value representing the +// integer index of a class. 
+// - optional (but recommended) label map(s) can be attached as +// AssociatedFile-s with type TENSOR_VALUE_LABELS, containing one label per +// line. The first such AssociatedFile (if any) is used to fill the +// `class_name` field of the results. The `display_name` field is filled +// from the AssociatedFile (if any) whose locale matches the +// `display_names_locale` field of the `ObjectDetectorOptions` used at +// creation time ("en" by default, i.e. English). If none of these are +// available, only the `index` field of the results will be filled. +// (kTfLiteFloat32) +// - scores tensor of size `[num_results]`, each value representing the score +// of the detected object. +// (kTfLiteFloat32) +// - integer num_results as a tensor of size `[1]` +// +// An example of such model can be found at: +// https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1 +// +// A CLI demo tool is available for easily trying out this API, and provides +// example usage. See: +// examples/task/vision/desktop/object_detector_demo.cc +class ObjectDetector : public BaseVisionTaskApi<DetectionResult> { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates an ObjectDetector from the provided options. A non-default + // OpResolver can be specified in order to support custom Ops or specify a + // subset of built-in Ops. + static tflite::support::StatusOr<std::unique_ptr<ObjectDetector>> + CreateFromOptions( + const ObjectDetectorOptions& options, + std::unique_ptr<tflite::OpResolver> resolver = + absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); + + // Performs actual detection on the provided FrameBuffer. + // + // The FrameBuffer can be of any size and any of the supported formats, i.e. + // RGBA, RGB, NV12, NV21, YV12, YV21. 
It is automatically pre-processed + // before inference in order to (and in this order): + // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to + // the dimensions of the model input tensor, + // - convert it to the colorspace of the input tensor (i.e. RGB, which is the + // only supported colorspace for now), + // - rotate it according to its `Orientation` so that inference is performed + // on an "upright" image. + // + // IMPORTANT: the returned bounding boxes are expressed in the unrotated input + // frame of reference coordinates system, i.e. in `[0, frame_buffer.width) x + // [0, frame_buffer.height)`, which are the dimensions of the underlying + // `frame_buffer` data before any `Orientation` flag gets applied. + // + // In particular, this implies that the returned bounding boxes may not be + // directly suitable for display if the input image is displayed *with* the + // `Orientation` flag taken into account according to the EXIF specification + // (http://jpegclub.org/exif_orientation.html): it may first need to be + // rotated. This is typically true when consuming camera frames on Android or + // iOS. + // + // For example, if the input `frame_buffer` has its `Orientation` flag set to + // `kLeftBottom` (i.e. the image will be rotated 90° clockwise during + // preprocessing to make it "upright"), then the same 90° clockwise rotation + // needs to be applied to the bounding box for display. + tflite::support::StatusOr<DetectionResult> Detect( + const FrameBuffer& frame_buffer); + + protected: + // Post-processing to transform the raw model outputs into detection results. + tflite::support::StatusOr<DetectionResult> Postprocess( + const std::vector<const TfLiteTensor*>& output_tensors, + const FrameBuffer& frame_buffer, const BoundingBox& roi) override; + + // Performs sanity checks on the provided ObjectDetectorOptions. 
+ static absl::Status SanityCheckOptions(const ObjectDetectorOptions& options); + + // Initializes the ObjectDetector from the provided ObjectDetectorOptions, + // whose ownership is transferred to this object. + absl::Status Init(std::unique_ptr<ObjectDetectorOptions>); + + // Performs pre-initialization actions. + virtual absl::Status PreInit(); + + private: + // Performs sanity checks on the model outputs and extracts their metadata. + absl::Status CheckAndSetOutputs(); + + // Performs sanity checks on the class whitelist/blacklist and forms the class + // index set. + absl::Status CheckAndSetClassIndexSet(); + + // Checks if the class at the provided index is allowed, i.e. whitelisted in + // case a whitelist is provided or not blacklisted if a blacklist is provided. + // Always returns true if no whitelist or blacklist were provided. + bool IsClassIndexAllowed(int class_index); + + // Given a DetectionResult object containing class indices, fills the name and + // display name from the label map. + absl::Status FillResultsFromLabelMap(DetectionResult* result); + + // The options used to build this ObjectDetector. + std::unique_ptr<ObjectDetectorOptions> options_; + + // This is populated by reading the label files from the TFLite Model + // Metadata: if no such files are available, this is left empty and the + // ObjectDetector will only be able to populate the `index` field of the + // detection results `classes` field. + std::vector<LabelMapItem> label_map_; + + // For each pack of 4 coordinates returned by the model, this denotes the + // order in which to get the left, top, right and bottom coordinates. + std::vector<unsigned int> bounding_box_corners_order_; + + // Set of whitelisted or blacklisted class indices. + struct ClassIndexSet { + absl::flat_hash_set<int> values; + bool is_whitelist; + }; + // Whitelisted or blacklisted class indices based on provided options at + // construction time. 
These are used to filter out results during + // post-processing. + ClassIndexSet class_index_set_; + + // Score threshold. Detections with a confidence below this value are + // discarded. If none is provided via metadata or options, -FLT_MAX is set as + // default value. + float score_threshold_; +}; + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_OBJECT_DETECTOR_H_ diff --git a/tensorflow_lite_support/cc/task/vision/proto/BUILD b/tensorflow_lite_support/cc/task/vision/proto/BUILD new file mode 100644 index 00000000..e294da76 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/BUILD @@ -0,0 +1,208 @@ +load("//tensorflow_lite_support/cc/port:build_defs.bzl", "support_cc_proto_library") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +# Common vision protos. + +proto_library( + name = "bounding_box_proto", + srcs = ["bounding_box.proto"], +) + +support_cc_proto_library( + name = "bounding_box_cc_proto", + srcs = ["bounding_box.proto"], + deps = [ + ":bounding_box_proto", + ], +) + +cc_library( + name = "bounding_box_proto_inc", + hdrs = ["bounding_box_proto_inc.h"], + deps = [":bounding_box_cc_proto"], +) + +proto_library( + name = "class_proto", + srcs = ["class.proto"], +) + +support_cc_proto_library( + name = "class_cc_proto", + srcs = ["class.proto"], + deps = [ + ":class_proto", + ], +) + +cc_library( + name = "class_proto_inc", + hdrs = ["class_proto_inc.h"], + deps = [":class_cc_proto"], +) + +# ObjectDetector protos. 
+ +proto_library( + name = "object_detector_options_proto", + srcs = ["object_detector_options.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto", + ], +) + +support_cc_proto_library( + name = "object_detector_options_cc_proto", + srcs = ["object_detector_options.proto"], + cc_deps = ["//tensorflow_lite_support/cc/task/core/proto:external_file_cc_proto"], + deps = [ + ":object_detector_options_proto", + ], +) + +cc_library( + name = "object_detector_options_proto_inc", + hdrs = ["object_detector_options_proto_inc.h"], + deps = [ + ":object_detector_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + ], +) + +proto_library( + name = "detections_proto", + srcs = ["detections.proto"], + deps = [ + ":bounding_box_proto", + ":class_proto", + ], +) + +support_cc_proto_library( + name = "detections_cc_proto", + srcs = ["detections.proto"], + cc_deps = [ + ":bounding_box_cc_proto", + ":class_cc_proto", + ], + deps = [ + ":detections_proto", + ], +) + +cc_library( + name = "detections_proto_inc", + hdrs = ["detections_proto_inc.h"], + deps = [ + ":bounding_box_proto_inc", + ":class_proto_inc", + ":detections_cc_proto", + ], +) + +# ImageClassifier protos. 
+ +proto_library( + name = "image_classifier_options_proto", + srcs = ["image_classifier_options.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto", + ], +) + +support_cc_proto_library( + name = "image_classifier_options_cc_proto", + srcs = ["image_classifier_options.proto"], + cc_deps = ["//tensorflow_lite_support/cc/task/core/proto:external_file_cc_proto"], + deps = [ + ":image_classifier_options_proto", + ], +) + +cc_library( + name = "image_classifier_options_proto_inc", + hdrs = ["image_classifier_options_proto_inc.h"], + deps = [ + ":image_classifier_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + ], +) + +proto_library( + name = "classifications_proto", + srcs = ["classifications.proto"], + deps = [ + ":class_proto", + ], +) + +support_cc_proto_library( + name = "classifications_cc_proto", + srcs = ["classifications.proto"], + cc_deps = [":class_cc_proto"], + deps = [ + ":classifications_proto", + ], +) + +cc_library( + name = "classifications_proto_inc", + hdrs = ["classifications_proto_inc.h"], + deps = [ + ":class_proto_inc", + ":classifications_cc_proto", + ], +) + +# ImageSegmenter protos. 
+ +proto_library( + name = "image_segmenter_options_proto", + srcs = ["image_segmenter_options.proto"], + deps = [ + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto", + ], +) + +support_cc_proto_library( + name = "image_segmenter_options_cc_proto", + srcs = ["image_segmenter_options.proto"], + cc_deps = ["//tensorflow_lite_support/cc/task/core/proto:external_file_cc_proto"], + deps = [ + ":image_segmenter_options_proto", + ], +) + +cc_library( + name = "image_segmenter_options_proto_inc", + hdrs = ["image_segmenter_options_proto_inc.h"], + deps = [ + ":image_segmenter_options_cc_proto", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + ], +) + +proto_library( + name = "segmentations_proto", + srcs = ["segmentations.proto"], +) + +support_cc_proto_library( + name = "segmentations_cc_proto", + srcs = ["segmentations.proto"], + deps = [ + ":segmentations_proto", + ], +) + +cc_library( + name = "segmentations_proto_inc", + hdrs = ["segmentations_proto_inc.h"], + deps = [":segmentations_cc_proto"], +) diff --git a/tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto b/tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto new file mode 100644 index 00000000..4c2e1302 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +// An integer bounding box, axis aligned. +message BoundingBox { + // The X coordinate of the top-left corner, in pixels. + optional int32 origin_x = 1; + // The Y coordinate of the top-left corner, in pixels. + optional int32 origin_y = 2; + // The width of the bounding box, in pixels. + optional int32 width = 3; + // The height of the bounding box, in pixels. + optional int32 height = 4; +} diff --git a/tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h new file mode 100644 index 00000000..ef84b156 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h @@ -0,0 +1,19 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_BOUNDING_BOX_PROTO_INC_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_BOUNDING_BOX_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box.pb.h" +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_BOUNDING_BOX_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/vision/proto/class.proto b/tensorflow_lite_support/cc/task/vision/proto/class.proto new file mode 100644 index 00000000..19e8ac1d --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/class.proto @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +// A single classification result. +message Class { + // The index of the class in the corresponding label map, usually packed in + // the TFLite Model Metadata [1]. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + optional int32 index = 1; + // The score for this class e.g. (but not necessarily) a probability in [0,1]. + optional float score = 2; + // A human readable name of the class filled from the label map. + optional string display_name = 3; + // An ID for the class, not necessarily human-readable (e.g. 
a Google + // Knowledge Graph ID [1]), filled from the label map. + // + // [1]: https://developers.google.com/knowledge-graph + optional string class_name = 4; +} diff --git a/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h new file mode 100644 index 00000000..2f9a409d --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/vision/proto/class.pb.h" +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASS_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/vision/proto/classifications.proto b/tensorflow_lite_support/cc/task/vision/proto/classifications.proto new file mode 100644 index 00000000..d3d9c66c --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/classifications.proto @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +import "tensorflow_lite_support/cc/task/vision/proto/class.proto"; + +// List of predicted classes (aka labels) for a given image classifier head. +message Classifications { + // The array of predicted classes, usually sorted by descending scores (e.g. + // from high to low probability). + repeated Class classes = 1; + // The index of the image classifier head these classes refer to. This is + // useful for multi-head models. + optional int32 head_index = 2; +} + +// Contains one set of results per image classifier head. +message ClassificationResult { + repeated Classifications classifications = 1; +} diff --git a/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h new file mode 100644 index 00000000..62a5f117 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h @@ -0,0 +1,22 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" + +#include "tensorflow_lite_support/cc/task/vision/proto/classifications.pb.h" +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_CLASSIFICATIONS_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/vision/proto/detections.proto b/tensorflow_lite_support/cc/task/vision/proto/detections.proto new file mode 100644 index 00000000..b600fc93 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/detections.proto @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +import "tensorflow_lite_support/cc/task/vision/proto/bounding_box.proto"; +import "tensorflow_lite_support/cc/task/vision/proto/class.proto"; + +// A single detected object. +message Detection { + // The bounding box. 
+ // + // IMPORTANT: when using the Task APIs, the bounding box is expressed in the + // unrotated input frame of reference coordinates system, i.e. in `[0, + // frame_buffer.width) x [0, frame_buffer.height)`, which are the dimensions + // of the underlying `frame_buffer` data before any `Orientation` flag gets + // applied. + // + // In particular, this implies that the returned bounding boxes may not be + // directly suitable for display if the input image is displayed *with* the + // `Orientation` flag taken into account according to the EXIF specification + // (http://jpegclub.org/exif_orientation.html): it may first need to be + // rotated. + // + // For example, if the input `frame_buffer` has its `Orientation` flag set to + // `kLeftBottom` (i.e. the image will be rotated 90° clockwise during + // preprocessing to make it "upright"), then the same 90° clockwise rotation + // needs to be applied to the bounding box for display. + optional BoundingBox bounding_box = 2; + // The candidate classes, sorted by descending score. + repeated Class classes = 3; + // Reserved tags. + reserved 1, 4; +} + +// List of detected objects. +message DetectionResult { + repeated Detection detections = 1; +} diff --git a/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h new file mode 100644 index 00000000..2b63cad6 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h @@ -0,0 +1,23 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" + +#include "tensorflow_lite_support/cc/task/vision/proto/detections.pb.h" +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_DETECTIONS_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto b/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto new file mode 100644 index 00000000..24cd85f3 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.proto @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +import "tensorflow_lite_support/cc/task/core/proto/external_file.proto"; + +// Options for setting up an ImageClassifier. +// Next Id: 14 +message ImageClassifierOptions { + // The external model file, as a single standalone TFLite file. If it is + // packed with TFLite Model Metadata [1], those are used to populate e.g. the + // label map, score calibration and recommended score thresholds. Models + // without any such metadata or partial metadata are supported, but may result + // in the image classifier providing degraded functionality; typically, a + // model that doesn't contain any label map won't be able to return any class + // or display names but will be limited to returning class indices. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + optional core.ExternalFile model_file_with_metadata = 10; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + optional string display_names_locale = 11 [default = "en"]; + + // The maximum number of top-scored classification results to return. If < 0, + // all available results will be returned. If 0, an invalid argument error is + // returned. + optional int32 max_results = 2 [default = -1]; + + // Score threshold in [0,1), overrides the ones provided in the model metadata + // (if any). Results below this value are rejected. + optional float score_threshold = 3; + + // Optional whitelist of class names. If non-empty, classifications whose + // class name is not in this set will be filtered out. Duplicate or unknown + // class names are ignored. Mutually exclusive with class_name_blacklist. + repeated string class_name_whitelist = 4; + + // Optional blacklist of class names. If non-empty, classifications whose + // class name is in this set will be filtered out. 
Duplicate or unknown + // class names are ignored. Mutually exclusive with class_name_whitelist. + repeated string class_name_blacklist = 5; + + // The number of threads to be used for TFLite ops that support + // multi-threading when running inference with CPU. + // num_threads should be greater than 0 or equal to -1. Setting num_threads to + // -1 has the effect to let TFLite runtime set the value. + optional int32 num_threads = 13 [default = -1]; + + // Reserved tags. + reserved 1, 6, 7, 8, 9, 12; +} diff --git a/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h new file mode 100644 index 00000000..03dcd759 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h @@ -0,0 +1,22 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" + +#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options.pb.h" +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_CLASSIFIER_OPTIONS_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto b/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto new file mode 100644 index 00000000..3afed86a --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.proto @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +import "tensorflow_lite_support/cc/task/core/proto/external_file.proto"; + +// Options for setting up an ImageSegmenter. +// Next Id: 8 +message ImageSegmenterOptions { + // The external model file, as a single standalone TFLite file. If it is + // packed with TFLite Model Metadata [1], those are used to populate label + // map. 
Models without any such metadata or partial metadata are supported, + // but may result in the segmenter providing degraded functionality; + // typically, a model that doesn't contain any label map won't be able to + // return any class or display names. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + optional core.ExternalFile model_file_with_metadata = 5; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + optional string display_names_locale = 6 [default = "en"]; + + // Output mask type. This allows specifying the type of post-processing to + // perform on the raw model results (see SegmentationResult proto for more). + enum OutputType { + UNSPECIFIED = 0; + // Gives a single output mask where each pixel represents the class which + // the pixel in the original image was predicted to belong to. + CATEGORY_MASK = 1; + // Gives a list of output masks where, for each mask, each pixel represents + // the prediction confidence, usually in the [0, 1] range. + CONFIDENCE_MASK = 2; + } + // Optional output mask type. + optional OutputType output_type = 3 [default = CATEGORY_MASK]; + + // The number of threads to be used for TFLite ops that support + // multi-threading when running inference with CPU. + // num_threads should be greater than 0 or equal to -1. Setting num_threads to + // -1 has the effect to let TFLite runtime set the value. + optional int32 num_threads = 7 [default = -1]; + + // Reserved tags. + reserved 1, 2, 4; +} diff --git a/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h new file mode 100644 index 00000000..aaaecf36 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h @@ -0,0 +1,22 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" + +#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options.pb.h" +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_IMAGE_SEGMENTER_OPTIONS_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto b/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto new file mode 100644 index 00000000..b55e9740 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/object_detector_options.proto @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +import "tensorflow_lite_support/cc/task/core/proto/external_file.proto"; + +// Options for setting up an ObjectDetector. +// Next Id: 8. +message ObjectDetectorOptions { + // The external model file, as a single standalone TFLite file packed with + // TFLite Model Metadata [1]. Those are mandatory, and used to populate e.g. + // the label map and recommended score threshold. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + optional core.ExternalFile model_file_with_metadata = 1; + + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + optional string display_names_locale = 2 [default = "en"]; + + // The maximum number of top-scored detection results to return. If < 0, all + // available results will be returned. If 0, an invalid argument error is + // returned. Note that models may intrinsically be limited to returning a + // maximum number of results N: if the provided value here is above N, only N + // results will be returned. + optional int32 max_results = 3 [default = -1]; + + // Score threshold to override the one provided in the model metadata (if + // any). Detection results with a score below this value are rejected. + optional float score_threshold = 4; + + // Optional whitelist of class names. If non-empty, detection results whose + // class name is not in this set will be filtered out. Duplicate or unknown + // class names are ignored. Mutually exclusive with class_name_blacklist. + repeated string class_name_whitelist = 5; + + // Optional blacklist of class names. If non-empty, detection results whose + // class name is in this set will be filtered out. Duplicate or unknown + // class names are ignored. 
Mutually exclusive with class_name_whitelist. + repeated string class_name_blacklist = 6; + + // The number of threads to be used for TFLite ops that support + // multi-threading when running inference with CPU. + // num_threads should be greater than 0 or equal to -1. Setting num_threads to + // -1 has the effect to let TFLite runtime set the value. + optional int32 num_threads = 7 [default = -1]; +} diff --git a/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h new file mode 100644 index 00000000..27898470 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h @@ -0,0 +1,22 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" + +#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options.pb.h" +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_OBJECT_DETECTOR_OPTIONS_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto b/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto new file mode 100644 index 00000000..259bee81 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/segmentations.proto @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package tflite.task.vision; + +// Results of performing image segmentation. +// Note that at the time, a single `Segmentation` element is expected to be +// returned; the field is made repeated for later extension to e.g. instance +// segmentation models, which may return one segmentation per object. 
+message SegmentationResult { + repeated Segmentation segmentation = 1; +} + +// Next Id: 6 +message Segmentation { + // Confidence mask. This is a flattened 2D-array in row major order. For each + // pixel, the value indicates the prediction confidence usually in the [0, 1] + // range where higher values represent a stronger confidence. Ultimately this + // is model specific, and other range of values might be used. + message ConfidenceMask { + repeated float value = 1 [packed = true]; + } + + // List of confidence masks with respect to the model output depth (this depth + // represents how many classes are supported). Note: some models have a single + // class (e.g. a sky segmentation model) which turns into a single confidence + // mask in this list. + message ConfidenceMasks { + repeated ConfidenceMask confidence_mask = 1; + } + + // IMPORTANT: segmentation masks are not direcly suited for display, in + // particular: + // * they are relative to the unrotated input frame, i.e. *not* taking into + // account the `Orientation` flag of the input FrameBuffer, + // * their dimensions are intrinsic to the model, i.e. *not* dependent on the + // input FrameBuffer dimensions. + // + // Example of such post-processing, assuming: + // * an input FrameBuffer with width=640, height=480, orientation=kLeftBottom + // (i.e. the image will be rotated 90° clockwise during preprocessing to + // make it "upright"), + // * a model outputting masks of size 224x224. + // In order to be directly displayable on top of the input image assumed to + // be displayed *with* the `Orientation` flag taken into account (according to + // the EXIF specification [1]), the masks need to be: + // * re-scaled to 640 x 480, + // * then rotated 90° clockwise. + // + // [1]: http://jpegclub.org/exif_orientation.html + oneof mask_oneof { + // Category mask. This is a flattened 2D-array of size `width` x `height`, + // in row major order. 
The value of each pixel in this mask represents the + // class to which the pixel belongs. + // See `colored_labels` for instructions on how to get pixel labels and + // display color. + bytes category_mask = 1; + + // One confidence masks of size `width` x `height` for each of the supported + // classes. The value of each pixel in these masks represents the confidence + // score for this particular class. + // See `colored_labels` for instructions on how to get pixel labels and + // display color. + ConfidenceMasks confidence_masks = 4; + } + // The width of the mask. This is an intrinsic parameter of the model being + // used, and does not depend on the input image dimensions. + optional int32 width = 2; + // The height of the mask. This is an intrinsic parameter of the model being + // used, and does not depend on the input image dimensions. + optional int32 height = 3; + + // Defines a label associated with an RGB color, for display purposes. + message ColoredLabel { + // The RGB color components for the label, in the [0, 255] range. + optional uint32 r = 1; + optional uint32 g = 2; + optional uint32 b = 3; + // The class name, as provided in the label map packed in the TFLite Model + // Metadata. + optional string class_name = 4; + // The display name, as provided in the label map (if available) packed in + // the TFLite Model Metadata. See `display_names_locale` field in + // ImageSegmenterOptions. + optional string display_name = 5; + } + + // The list of colored labels for all the supported categories. Depending on + // which is present, this list is in 1:1 correspondence with: + // * `category_mask` pixel values, i.e. a pixel with value `i` is + // associated with `colored_labels[i]`, + // * `confidence_masks` indices, i.e. `confidence_masks[i]` is associated with + // `colored_labels[i]`. 
+ repeated ColoredLabel colored_labels = 5; +} diff --git a/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h b/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h new file mode 100644 index 00000000..cfc96e69 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h @@ -0,0 +1,19 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_ + +#include "tensorflow_lite_support/cc/task/vision/proto/segmentations.pb.h" +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_PROTO_SEGMENTATIONS_PROTO_INC_H_ diff --git a/tensorflow_lite_support/cc/task/vision/utils/BUILD b/tensorflow_lite_support/cc/task/vision/utils/BUILD new file mode 100644 index 00000000..89951451 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/BUILD @@ -0,0 +1,109 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:users", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "score_calibration", + srcs = ["score_calibration.cc"], + hdrs = ["score_calibration.h"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + 
"//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision/core:label_map_item", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "frame_buffer_common_utils", + srcs = [ + "frame_buffer_common_utils.cc", + ], + hdrs = [ + "frame_buffer_common_utils.h", + "frame_buffer_utils_interface.h", + ], + deps = [ + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "frame_buffer_utils", + srcs = [ + "frame_buffer_utils.cc", + ], + hdrs = [ + "frame_buffer_utils.h", + ], + deps = [ + ":frame_buffer_common_utils", + ":libyuv_frame_buffer_utils", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + "@org_tensorflow//tensorflow/lite/kernels:op_macros", + "@org_tensorflow//tensorflow/lite/kernels/internal:compatibility", + ], +) + +cc_library( + name = "libyuv_frame_buffer_utils", + srcs = ["libyuv_frame_buffer_utils.cc"], + hdrs = ["libyuv_frame_buffer_utils.h"], + deps = [ + 
":frame_buffer_common_utils", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@libyuv", + ], +) + +cc_library( + name = "image_tensor_specs", + srcs = ["image_tensor_specs.cc"], + hdrs = ["image_tensor_specs.h"], + deps = [ + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:tflite_engine", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + ], +) diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc new file mode 100644 index 00000000..fa9b05f5 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc @@ -0,0 +1,428 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" + +#include <string> +#include <vector> + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" + +namespace tflite { +namespace task { +namespace vision { +namespace { + +using ::tflite::support::StatusOr; + +constexpr int kRgbaChannels = 4; +constexpr int kRgbChannels = 3; +constexpr int kGrayChannel = 1; + +// Creates a FrameBuffer from raw NV12 buffer and passing arguments. +std::unique_ptr<FrameBuffer> CreateFromNV12RawBuffer( + const uint8* input, FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation, const absl::Time timestamp) { + const std::vector<FrameBuffer::Plane> planes_nv12 = { + {input, /*stride=*/{dimension.width, kGrayChannel}}, + {input + dimension.Size(), /*stride=*/{dimension.width, 2}}}; + return FrameBuffer::Create(planes_nv12, dimension, FrameBuffer::Format::kNV12, + orientation, timestamp); +} + +// Creates a FrameBuffer from raw NV21 buffer and passing arguments. +std::unique_ptr<FrameBuffer> CreateFromNV21RawBuffer( + const uint8* input, FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation, const absl::Time timestamp) { + FrameBuffer::Plane input_plane = {/*buffer=*/input, + /*stride=*/{dimension.width, kGrayChannel}}; + return FrameBuffer::Create({input_plane}, dimension, + FrameBuffer::Format::kNV21, orientation, + timestamp); +} + +// Indicates whether the given buffers have the same dimensions. +bool AreBufferDimsEqual(const FrameBuffer& buffer1, + const FrameBuffer& buffer2) { + return buffer1.dimension() == buffer2.dimension(); +} + +// Indicates whether the given buffers formats are compatible. 
Same formats are +// compatible and all YUV family formats (e.g. NV21, NV12, YV12, YV21, etc) are +// compatible. +bool AreBufferFormatsCompatible(const FrameBuffer& buffer1, + const FrameBuffer& buffer2) { + switch (buffer1.format()) { + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kRGB: + return (buffer2.format() == FrameBuffer::Format::kRGBA || + buffer2.format() == FrameBuffer::Format::kRGB); + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return (buffer2.format() == FrameBuffer::Format::kNV12 || + buffer2.format() == FrameBuffer::Format::kNV21 || + buffer2.format() == FrameBuffer::Format::kYV12 || + buffer2.format() == FrameBuffer::Format::kYV21); + case FrameBuffer::Format::kGRAY: + default: + return buffer1.format() == buffer2.format(); + } +} + +} // namespace + +// Miscellaneous Methods +// ----------------------------------------------------------------- +int GetFrameBufferByteSize(FrameBuffer::Dimension dimension, + FrameBuffer::Format format) { + switch (format) { + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return /*y plane*/ dimension.Size() + + /*uv plane*/ ((static_cast<float>(dimension.width + 1) / 2) * + (static_cast<float>(dimension.height + 1) / 2) * 2); + case FrameBuffer::Format::kRGB: + return dimension.Size() * 3; + case FrameBuffer::Format::kRGBA: + return dimension.Size() * 4; + case FrameBuffer::Format::kGRAY: + return dimension.Size(); + default: + return 0; + } +} + +StatusOr<int> GetPixelStrides(FrameBuffer::Format format) { + switch (format) { + case FrameBuffer::Format::kGRAY: + return kGrayPixelBytes; + case FrameBuffer::Format::kRGB: + return kRgbPixelBytes; + case FrameBuffer::Format::kRGBA: + return kRgbaPixelBytes; + default: + return absl::InvalidArgumentError(absl::StrFormat( + "GetPixelStrides does not support format: 
%i.", format)); + } +} + +StatusOr<const uint8*> GetUvRawBuffer(const FrameBuffer& buffer) { + if (buffer.format() != FrameBuffer::Format::kNV12 && + buffer.format() != FrameBuffer::Format::kNV21) { + return absl::InvalidArgumentError( + "Only support getting biplanar UV buffer from NV12/NV21 frame buffer."); + } + ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + const uint8* uv_buffer = buffer.format() == FrameBuffer::Format::kNV12 + ? yuv_data.u_buffer + : yuv_data.v_buffer; + return uv_buffer; +} + +StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension( + FrameBuffer::Dimension dimension, FrameBuffer::Format format) { + if (dimension.width <= 0 || dimension.height <= 0) { + return absl::InvalidArgumentError( + absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width, + dimension.height)); + } + switch (format) { + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return FrameBuffer::Dimension{(dimension.width + 1) / 2, + (dimension.height + 1) / 2}; + default: + return absl::InvalidArgumentError( + absl::StrFormat("Input format is not YUV-like: %i.", format)); + } +} + +FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1) { + return {x1 - x0 + 1, y1 - y0 + 1}; +} + +// Validation Methods +// ----------------------------------------------------------------- + +absl::Status ValidateBufferPlaneMetadata(const FrameBuffer& buffer) { + if (buffer.plane_count() < 1) { + return absl::InvalidArgumentError( + "There must be at least 1 plane specified."); + } + + for (int i = 0; i < buffer.plane_count(); i++) { + if (buffer.plane(i).stride.row_stride_bytes == 0 || + buffer.plane(i).stride.pixel_stride_bytes == 0) { + return absl::InvalidArgumentError("Invalid stride information."); + } + } + + return absl::OkStatus(); +} + +absl::Status ValidateBufferFormat(const FrameBuffer& buffer) { + switch 
(buffer.format()) { + case FrameBuffer::Format::kGRAY: + case FrameBuffer::Format::kRGB: + case FrameBuffer::Format::kRGBA: + if (buffer.plane_count() == 1) return absl::OkStatus(); + return absl::InvalidArgumentError( + "Plane count must be 1 for grayscale and RGB[a] buffers."); + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kYV21: + case FrameBuffer::Format::kYV12: + return absl::OkStatus(); + default: + return absl::InternalError( + absl::StrFormat("Unsupported buffer format: %i.", buffer.format())); + } +} + +absl::Status ValidateBufferFormats(const FrameBuffer& buffer1, + const FrameBuffer& buffer2) { + RETURN_IF_ERROR(ValidateBufferFormat(buffer1)); + RETURN_IF_ERROR(ValidateBufferFormat(buffer2)); + return absl::OkStatus(); +} + +absl::Status ValidateResizeBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer) { + bool valid_format = false; + switch (buffer.format()) { + case FrameBuffer::Format::kGRAY: + case FrameBuffer::Format::kRGB: + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + valid_format = (buffer.format() == output_buffer.format()); + break; + case FrameBuffer::Format::kRGBA: + valid_format = (output_buffer.format() == FrameBuffer::Format::kRGBA || + output_buffer.format() == FrameBuffer::Format::kRGB); + break; + default: + return absl::InternalError( + absl::StrFormat("Unsupported buffer format: %i.", buffer.format())); + } + if (!valid_format) { + return absl::InvalidArgumentError( + "Input and output buffer formats must match."); + } + return ValidateBufferFormats(buffer, output_buffer); +} + +absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer, + int angle_deg) { + if (!AreBufferFormatsCompatible(buffer, output_buffer)) { + return absl::InvalidArgumentError( + "Input and output buffer formats must match."); + } + + const bool 
is_dimension_change = (angle_deg / 90) % 2 == 1; + const bool are_dimensions_rotated = + (buffer.dimension().width == output_buffer.dimension().height) && + (buffer.dimension().height == output_buffer.dimension().width); + const bool are_dimensions_equal = + buffer.dimension() == output_buffer.dimension(); + + if (angle_deg >= 360 || angle_deg <= 0 || angle_deg % 90 != 0) { + return absl::InvalidArgumentError( + "Rotation angle must be between 0 and 360, in multiples of 90 " + "degrees."); + } else if ((is_dimension_change && !are_dimensions_rotated) || + (!is_dimension_change && !are_dimensions_equal)) { + return absl::InvalidArgumentError( + "Output buffer has invalid dimensions for rotation."); + } + return absl::OkStatus(); +} + +absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer, int x0, + int y0, int x1, int y1) { + if (!AreBufferFormatsCompatible(buffer, output_buffer)) { + return absl::InvalidArgumentError( + "Input and output buffer formats must match."); + } + + bool is_buffer_size_valid = + ((x1 < buffer.dimension().width) && y1 < buffer.dimension().height); + bool are_points_valid = (x0 >= 0) && (y0 >= 0) && (x1 >= x0) && (y1 >= y0); + + if (!is_buffer_size_valid || !are_points_valid) { + return absl::InvalidArgumentError("Invalid crop coordinates."); + } + return absl::OkStatus(); +} + +absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer) { + if (!AreBufferFormatsCompatible(buffer, output_buffer)) { + return absl::InvalidArgumentError( + "Input and output buffer formats must match."); + } + return AreBufferDimsEqual(buffer, output_buffer) + ? 
absl::OkStatus() + : absl::InvalidArgumentError( + "Input and output buffers must have the same dimensions."); +} + +absl::Status ValidateConvertFormats(FrameBuffer::Format from_format, + FrameBuffer::Format to_format) { + if (from_format == to_format) { + return absl::InvalidArgumentError("Formats must be different."); + } + + switch (from_format) { + case FrameBuffer::Format::kGRAY: + return absl::InvalidArgumentError( + "Grayscale format does not convert to other formats."); + case FrameBuffer::Format::kRGB: + if (to_format == FrameBuffer::Format::kRGBA) { + return absl::InvalidArgumentError( + "RGB format does not convert to RGBA"); + } + return absl::OkStatus(); + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return absl::OkStatus(); + default: + return absl::InternalError( + absl::StrFormat("Unsupported buffer format: %i.", from_format)); + } +} + +// Creation Methods +// ----------------------------------------------------------------- + +// Creates a FrameBuffer from raw RGBA buffer and passing arguments. +std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer( + const uint8* input, FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation, const absl::Time timestamp) { + FrameBuffer::Plane input_plane = { + /*buffer=*/input, + /*stride=*/{dimension.width * kRgbaChannels, kRgbaChannels}}; + return FrameBuffer::Create({input_plane}, dimension, + FrameBuffer::Format::kRGBA, orientation, + timestamp); +} + +// Creates a FrameBuffer from raw RGB buffer and passing arguments. 
+std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer( + const uint8* input, FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation, const absl::Time timestamp) { + FrameBuffer::Plane input_plane = { + /*buffer=*/input, + /*stride=*/{dimension.width * kRgbChannels, kRgbChannels}}; + return FrameBuffer::Create({input_plane}, dimension, + FrameBuffer::Format::kRGB, orientation, timestamp); +} + +// Creates a FrameBuffer from raw grayscale buffer and passing arguments. +std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer( + const uint8* input, FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation, const absl::Time timestamp) { + FrameBuffer::Plane input_plane = {/*buffer=*/input, + /*stride=*/{dimension.width, kGrayChannel}}; + return FrameBuffer::Create({input_plane}, dimension, + FrameBuffer::Format::kGRAY, orientation, + timestamp); +} + +// Creates a FrameBuffer from raw YUV buffer and passing arguments. +StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer( + const uint8* y_plane, const uint8* u_plane, const uint8* v_plane, + FrameBuffer::Format format, FrameBuffer::Dimension dimension, + int row_stride_y, int row_stride_uv, int pixel_stride_uv, + FrameBuffer::Orientation orientation, const absl::Time timestamp) { + const int pixel_stride_y = 1; + std::vector<FrameBuffer::Plane> planes; + if (format == FrameBuffer::Format::kNV21 || + format == FrameBuffer::Format::kYV12) { + planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}}, + {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}, + {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}}; + } else if (format == FrameBuffer::Format::kNV12 || + format == FrameBuffer::Format::kYV21) { + planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}}, + {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}, + {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}}; + } else { + return absl::InvalidArgumentError( + absl::StrFormat("Input format is not 
YUV-like: %i.", format)); + } + return FrameBuffer::Create(planes, dimension, format, orientation, timestamp); +} + +StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer( + const uint8* buffer, FrameBuffer::Dimension dimension, + const FrameBuffer::Format target_format, + FrameBuffer::Orientation orientation, absl::Time timestamp) { + switch (target_format) { + case FrameBuffer::Format::kNV12: + return CreateFromNV12RawBuffer(buffer, dimension, orientation, timestamp); + case FrameBuffer::Format::kNV21: + return CreateFromNV21RawBuffer(buffer, dimension, orientation, timestamp); + case FrameBuffer::Format::kYV12: { + ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension, + GetUvPlaneDimension(dimension, target_format)); + return CreateFromYuvRawBuffer( + /*y_plane=*/buffer, + /*u_plane=*/buffer + dimension.Size() + uv_dimension.Size(), + /*v_plane=*/buffer + dimension.Size(), target_format, dimension, + /*row_stride_y=*/dimension.width, uv_dimension.width, + /*pixel_stride_uv=*/1, orientation, timestamp); + } + case FrameBuffer::Format::kYV21: { + ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension, + GetUvPlaneDimension(dimension, target_format)); + return CreateFromYuvRawBuffer( + /*y_plane=*/buffer, /*u_plane=*/buffer + dimension.Size(), + /*v_plane=*/buffer + dimension.Size() + uv_dimension.Size(), + target_format, dimension, /*row_stride_y=*/dimension.width, + uv_dimension.width, + /*pixel_stride_uv=*/1, orientation, timestamp); + } + case FrameBuffer::Format::kRGBA: + return CreateFromRgbaRawBuffer(buffer, dimension, orientation, timestamp); + case FrameBuffer::Format::kRGB: + return CreateFromRgbRawBuffer(buffer, dimension, orientation, timestamp); + case FrameBuffer::Format::kGRAY: + return CreateFromGrayRawBuffer(buffer, dimension, orientation, timestamp); + default: + + return absl::InternalError( + absl::StrFormat("Unsupported buffer format: %i.", target_format)); + } +} + +} // namespace vision +} // namespace task +} // namespace 
tflite diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h new file mode 100644 index 00000000..e250d154 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h @@ -0,0 +1,143 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_COMMON_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_COMMON_UTILS_H_ + +#include <memory> + +#include "absl/status/status.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" + +namespace tflite { +namespace task { +namespace vision { + +constexpr int kRgbaPixelBytes = 4, kRgbPixelBytes = 3, kGrayPixelBytes = 1; + +// Miscellaneous Methods +// ----------------------------------------------------------------- + +// Returns the frame buffer size in bytes based on the input format and +// dimensions. GRAY, YV12/YV21 are in the planar formats, NV12/NV21 are in the +// semi-planar formats with the interleaved UV planes. RGB/RGBA are in the +// interleaved format. 
+int GetFrameBufferByteSize(FrameBuffer::Dimension dimension, + FrameBuffer::Format format); + +// Returns pixel stride info for kGRAY, kRGB, kRGBA formats. +tflite::support::StatusOr<int> GetPixelStrides(FrameBuffer::Format format); + +// Returns the biplanar UV raw buffer for NV12/NV21 frame buffer. +tflite::support::StatusOr<const uint8*> GetUvRawBuffer( + const FrameBuffer& buffer); + +// Returns U or V plane dimension with the given buffer `dimension` and +// `format`. Only supports NV12/NV21/YV12/YV21 formats. Returns +// InvalidArgumentError if 'dimension' is invalid or 'format' is other than the +// supported formats. This method assumes the U and V planes share the same +// dimension, in particular for the YV12 / YV21 formats. +tflite::support::StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension( + FrameBuffer::Dimension dimension, FrameBuffer::Format format); + +// Returns crop dimension based on crop start and end points. +FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1); + +// Validation Methods +// ----------------------------------------------------------------- + +// Validates that the given buffer has the correct metadata. Returns error +// state when any buffer has missing stride info. +absl::Status ValidateBufferPlaneMetadata(const FrameBuffer& buffer); + +// Validates that the given buffer has the correct format for its configuration. +absl::Status ValidateBufferFormat(const FrameBuffer& buffer); + +// Validates that the given buffers have the correct format for their +// configuration. +absl::Status ValidateBufferFormats(const FrameBuffer& buffer1, + const FrameBuffer& buffer2); + +// Validates the given inputs for resizing `buffer`. +absl::Status ValidateResizeBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer); + +// Validates the given inputs for rotating `buffer`. 
+absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer, + int angle_deg); + +// Validates the given inputs for cropping `buffer`. +// +// (x0, y0) represents the top-left point of the buffer. +// (x1, y1) represents the bottom-right point of the buffer. +absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer, int x0, + int y0, int x1, int y1); + +// Validates the given inputs for flipping `buffer` horizontally or vertically. +absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer); + +// Validates that `from_format` can be converted to `to_format`. +// +// The given formats must not be equal. +absl::Status ValidateConvertFormats(FrameBuffer::Format from_format, + FrameBuffer::Format to_format); + +// Creation Methods +// ----------------------------------------------------------------- + +// Creates a FrameBuffer from raw RGBA buffer and passing arguments. +std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer( + const uint8* input, FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now()); + +// Creates a FrameBuffer from raw RGB buffer and passing arguments. +std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer( + const uint8* input, FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now()); + +// Creates a FrameBuffer from raw grayscale buffer and passing arguments. +std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer( + const uint8* input, FrameBuffer::Dimension dimension, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now()); + +// Creates a FrameBuffer from raw YUV buffer and passing arguments. 
+tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer( + const uint8* y_plane, const uint8* u_plane, const uint8* v_plane, + FrameBuffer::Format format, FrameBuffer::Dimension dimension, + int row_stride_y, int row_stride_uv, int pixel_stride_uv, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now()); + +// Creates an instance of FrameBuffer from raw buffer and passing arguments. +tflite::support::StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer( + const uint8* buffer, FrameBuffer::Dimension dimension, + FrameBuffer::Format target_format, + FrameBuffer::Orientation orientation = FrameBuffer::Orientation::kTopLeft, + absl::Time timestamp = absl::Now()); + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_COMMON_UTILS_H_ diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc new file mode 100644 index 00000000..9b9d830e --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.cc @@ -0,0 +1,619 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h" + +#include <algorithm> +#include <iterator> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h" + +namespace tflite { +namespace task { +namespace vision { + +namespace { + +// Exif grouping to help determine the rotation and flipping needed between +// different orientations. +constexpr int kExifGroup[] = {1, 6, 3, 8, 2, 5, 4, 7}; +// Exif group size. +constexpr int kExifGroupSize = 4; + +// Returns orientation position in Exif group. +static int GetOrientationIndex(FrameBuffer::Orientation orientation) { + const int* index = std::find(kExifGroup, kExifGroup + kExifGroupSize * 2, + static_cast<int>(orientation)); + if (index < kExifGroup + kExifGroupSize * 2) { + return std::distance(kExifGroup, index); + } + return -1; +} + +// Returns the coordinates of `box` with respect to its containing image +// (dimension defined by `width` and `height`) orientation change. The `angle` +// is defined in counterclockwise degree in one of the values [0, 90, 180, 270]. +// +// The below diagrams illustrate calling this method with 90 CCW degree. +// +// The [1]-[4] denotes image corners and 1 - 4 denotes the box corners. The * +// denotes the current origin. 
+// +// width +// [1]*----------------[2] +// | | +// | | +// | 1*-----2 | height +// | | box | | +// | 3------4 | +// [3]-----------------[4] +// +// When rotate the above image by 90 CCW degree, the origin also changes +// respects to its containing coordinate space. +// +// height +// [2]*----------[4] +// | | +// | 2*---4 | +// | |box | | +// | | | | width +// | 1----3 | +// | | +// | | +// | | +// [1]-----------[3] +// +// The origin is always defined by the top left corner. After rotation, the +// box origin changed from 1 to 2. +// The new box origin is (x:box.origin_y, y:width - (box.origin_x + box.width). +// The new box dimension is (w: box.height, h: box.width). +// +static BoundingBox RotateBoundingBox(const BoundingBox& box, int angle, + FrameBuffer::Dimension frame_dimension) { + int rx = box.origin_x(), ry = box.origin_y(), rw = box.width(), + rh = box.height(); + const int box_right_bound = + frame_dimension.width - (box.origin_x() + box.width()); + const int box_bottom_bound = + frame_dimension.height - (box.origin_y() + box.height()); + switch (angle) { + case 90: + rx = box.origin_y(); + ry = box_right_bound; + using std::swap; + swap(rw, rh); + break; + case 180: + rx = box_right_bound; + ry = box_bottom_bound; + break; + case 270: + rx = box_bottom_bound; + ry = box.origin_x(); + using std::swap; + swap(rw, rh); + break; + } + BoundingBox result; + result.set_origin_x(rx); + result.set_origin_y(ry); + result.set_width(rw); + result.set_height(rh); + return result; +} + +// Returns the input coordinates with respect to its containing image (dimension +// defined by `width` and `height`) orientation change. The `angle` is defined +// in counterclockwise degree in one of the values [0, 90, 180, 270]. +// +// See `RotateBoundingBox` above for more details. 
+static void RotateCoordinates(int from_x, int from_y, int angle, + const FrameBuffer::Dimension& frame_dimension, + int* to_x, int* to_y) { + switch (angle) { + case 0: + *to_x = from_x; + *to_y = from_y; + break; + case 90: + *to_x = from_y; + *to_y = frame_dimension.width - from_x - 1; + break; + case 180: + *to_x = frame_dimension.width - from_x - 1; + *to_y = frame_dimension.height - from_y - 1; + break; + case 270: + *to_x = frame_dimension.height - from_y - 1; + *to_y = from_x; + break; + } +} + +} // namespace + +int GetBufferByteSize(FrameBuffer::Dimension dimension, + FrameBuffer::Format format) { + return GetFrameBufferByteSize(dimension, format); +} + +FrameBufferUtils::FrameBufferUtils(ProcessEngine engine) { + switch (engine) { + case ProcessEngine::kLibyuv: + utils_ = absl::make_unique<LibyuvFrameBufferUtils>(); + break; + default: + TF_LITE_FATAL( + absl::StrFormat("Unexpected ProcessEngine: %d.", engine).c_str()); + } +} + +BoundingBox OrientBoundingBox(const BoundingBox& from_box, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, + FrameBuffer::Dimension from_dimension) { + BoundingBox to_box = from_box; + OrientParams params = GetOrientParams(from_orientation, to_orientation); + // First, rotate if needed. + if (params.rotation_angle_deg > 0) { + to_box = + RotateBoundingBox(to_box, params.rotation_angle_deg, from_dimension); + } + // Then perform horizontal or vertical flip if needed. 
+ FrameBuffer::Dimension to_dimension = from_dimension; + if (params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270) { + to_dimension.Swap(); + } + if (params.flip == OrientParams::FlipType::kVertical) { + to_box.set_origin_y(to_dimension.height - + (to_box.origin_y() + to_box.height())); + } + if (params.flip == OrientParams::FlipType::kHorizontal) { + to_box.set_origin_x(to_dimension.width - + (to_box.origin_x() + to_box.width())); + } + return to_box; +} + +BoundingBox OrientAndDenormalizeBoundingBox( + float from_left, float from_top, float from_right, float from_bottom, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, + FrameBuffer::Dimension from_dimension) { + BoundingBox from_box; + from_box.set_origin_x(from_left * from_dimension.width); + from_box.set_origin_y(from_top * from_dimension.height); + from_box.set_width(round(abs(from_right - from_left) * from_dimension.width)); + from_box.set_height( + round(abs(from_bottom - from_top) * from_dimension.height)); + BoundingBox to_box = OrientBoundingBox(from_box, from_orientation, + to_orientation, from_dimension); + return to_box; +} + +void OrientCoordinates(int from_x, int from_y, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, + FrameBuffer::Dimension from_dimension, int* to_x, + int* to_y) { + *to_x = from_x; + *to_y = from_y; + OrientParams params = GetOrientParams(from_orientation, to_orientation); + // First, rotate if needed. + if (params.rotation_angle_deg > 0) { + RotateCoordinates(from_x, from_y, params.rotation_angle_deg, from_dimension, + to_x, to_y); + } + // Then perform horizontal or vertical flip if needed. 
+ FrameBuffer::Dimension to_dimension = from_dimension; + if (params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270) { + to_dimension.Swap(); + } + if (params.flip == OrientParams::FlipType::kVertical) { + *to_y = to_dimension.height - *to_y - 1; + } + if (params.flip == OrientParams::FlipType::kHorizontal) { + *to_x = to_dimension.width - *to_x - 1; + } +} + +// The algorithm is based on grouping orientations into two groups with specific +// order. The two groups of orientation are {1, 6, 3, 8} and {2, 5, 4, 7}. See +// image (https://www.impulseadventure.com/photo/images/orient_flag.gif) for +// the visual grouping illustration. +// +// Each group contains elements can be transformed into one another by rotation. +// The elements order within a group is important such that the distance between +// the elements indicates the multiples of 90 degree needed to orient from one +// element to another. For example, to orient element 1 to element 6, a 90 +// degree CCW rotation is needed. +// +// The corresponding order between the two groups is important such that the +// even index defined the need for horizontal flipping and the odd index defined +// the need for vertical flipping. For example, to orient element 1 to element 2 +// (even index) a horizontal flipping is needed. +// +// The implementation determines the group and element index of from and to +// orientations. Based on the group and element index information, the above +// characteristic is used to calculate the rotation angle and the need for +// horizontal or vertical flipping. 
+OrientParams GetOrientParams(FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation) { + int from_index = GetOrientationIndex(from_orientation); + int to_index = GetOrientationIndex(to_orientation); + int angle = 0; + absl::optional<OrientParams::FlipType> flip; + + TFLITE_DCHECK(from_index > -1 && to_index > -1); + + if ((from_index < kExifGroupSize && to_index < kExifGroupSize) || + (from_index >= kExifGroupSize && to_index >= kExifGroupSize)) { + // Only needs rotation. + + // The orientations' position differences translates to how many + // multiple of 90 degrees it needs for conversion. The position difference + // calculation within a group is circular. + angle = (kExifGroupSize - (from_index - to_index)) % kExifGroupSize * 90; + } else { + // Needs rotation and flipping. + int from_index_mod = from_index % kExifGroupSize; + int to_index_mod = to_index % kExifGroupSize; + angle = (kExifGroupSize - (from_index_mod - to_index_mod)) % + kExifGroupSize * 90; + if (to_index_mod % 2 == 1) { + flip = OrientParams::FlipType::kVertical; + } else { + flip = OrientParams::FlipType::kHorizontal; + } + } + return {angle, flip}; +} + +bool RequireDimensionSwap(FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation) { + OrientParams params = GetOrientParams(from_orientation, to_orientation); + return params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270; +} + +absl::Status FrameBufferUtils::Crop(const FrameBuffer& buffer, int x0, int y0, + int x1, int y1, + FrameBuffer* output_buffer) { + TFLITE_DCHECK(utils_ != nullptr); + return utils_->Crop(buffer, x0, y0, x1, y1, output_buffer); +} + +FrameBuffer::Dimension FrameBufferUtils::GetSize( + const FrameBuffer& buffer, const FrameBufferOperation& operation) { + FrameBuffer::Dimension dimension = buffer.dimension(); + if (absl::holds_alternative<OrientOperation>(operation)) { + OrientParams params = + GetOrientParams(buffer.orientation(), + 
absl::get<OrientOperation>(operation).to_orientation); + if (params.rotation_angle_deg == 90 || params.rotation_angle_deg == 270) { + dimension.Swap(); + } + } else if (absl::holds_alternative<CropResizeOperation>(operation)) { + const auto& crop_resize = absl::get<CropResizeOperation>(operation); + dimension = crop_resize.resize_dimension; + } + return dimension; +} + +std::vector<FrameBuffer::Plane> FrameBufferUtils::GetPlanes( + const uint8* buffer, FrameBuffer::Dimension dimension, + FrameBuffer::Format format) { + std::vector<FrameBuffer::Plane> planes; + switch (format) { + case FrameBuffer::Format::kGRAY: + planes.push_back({/*buffer=*/buffer, + /*stride=*/{/*row_stride_bytes=*/dimension.width * 1, + /*pixel_stride_bytes=*/1}}); + break; + case FrameBuffer::Format::kRGB: + planes.push_back({/*buffer=*/buffer, + /*stride=*/{/*row_stride_bytes=*/dimension.width * 3, + /*pixel_stride_bytes=*/3}}); + break; + case FrameBuffer::Format::kRGBA: + planes.push_back({/*buffer=*/buffer, + /*stride=*/{/*row_stride_bytes=*/dimension.width * 4, + /*pixel_stride_bytes=*/4}}); + break; + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kNV12: { + planes.push_back( + {buffer, /*stride=*/{/*row_stride_bytes=*/dimension.width, + /*pixel_stride_bytes=*/1}}); + planes.push_back({buffer + (dimension.width * dimension.height), + /*stride=*/{/*row_stride_bytes=*/dimension.width, + /*pixel_stride_bytes=*/2}}); + } break; + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: { + const int y_buffer_size = dimension.width * dimension.height; + const int uv_row_stride = (dimension.width + 1) / 2; + const int uv_buffer_size = uv_row_stride * (dimension.height + 1) / 2; + planes.push_back( + {buffer, /*stride=*/{/*row_stride_bytes=*/dimension.width, + /*pixel_stride_bytes=*/1}}); + planes.push_back( + {buffer + y_buffer_size, /*stride=*/{ + /*row_stride_bytes=*/uv_row_stride, /*pixel_stride_bytes=*/1}}); + planes.push_back( + {buffer + y_buffer_size + 
uv_buffer_size, /*stride=*/{ + /*row_stride_bytes=*/uv_row_stride, /*pixel_stride_bytes=*/1}}); + } break; + default: + break; + } + return planes; +} + +FrameBuffer::Orientation FrameBufferUtils::GetOrientation( + const FrameBuffer& buffer, const FrameBufferOperation& operation) { + if (absl::holds_alternative<OrientOperation>(operation)) { + return absl::get<OrientOperation>(operation).to_orientation; + } + return buffer.orientation(); +} + +FrameBuffer::Format FrameBufferUtils::GetFormat( + const FrameBuffer& buffer, const FrameBufferOperation& operation) { + if (absl::holds_alternative<ConvertOperation>(operation)) { + return absl::get<ConvertOperation>(operation).to_format; + } + return buffer.format(); +} + +absl::Status FrameBufferUtils::Execute(const FrameBuffer& buffer, + const FrameBufferOperation& operation, + FrameBuffer* output_buffer) { + if (absl::holds_alternative<CropResizeOperation>(operation)) { + const auto& params = absl::get<CropResizeOperation>(operation); + RETURN_IF_ERROR( + Crop(buffer, params.crop_origin_x, params.crop_origin_y, + (params.crop_dimension.width + params.crop_origin_x - 1), + (params.crop_dimension.height + params.crop_origin_y - 1), + output_buffer)); + } else if (absl::holds_alternative<ConvertOperation>(operation)) { + RETURN_IF_ERROR(Convert(buffer, output_buffer)); + } else if (absl::holds_alternative<OrientOperation>(operation)) { + RETURN_IF_ERROR(Orient(buffer, output_buffer)); + } else { + return absl::UnimplementedError(absl::StrFormat( + "FrameBufferOperation %i is not supported.", operation.index())); + } + return absl::OkStatus(); +} + +absl::Status FrameBufferUtils::Resize(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + TFLITE_DCHECK(utils_ != nullptr); + return utils_->Resize(buffer, output_buffer); +} + +absl::Status FrameBufferUtils::Rotate(const FrameBuffer& buffer, + RotationDegree rotation, + FrameBuffer* output_buffer) { + TFLITE_DCHECK(utils_ != nullptr); + return utils_->Rotate(buffer, 90 
* static_cast<int>(rotation), output_buffer); +} + +absl::Status FrameBufferUtils::FlipHorizontally(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + TFLITE_DCHECK(utils_ != nullptr); + return utils_->FlipHorizontally(buffer, output_buffer); +} + +absl::Status FrameBufferUtils::FlipVertically(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + TFLITE_DCHECK(utils_ != nullptr); + return utils_->FlipVertically(buffer, output_buffer); +} + +absl::Status FrameBufferUtils::Convert(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + TFLITE_DCHECK(utils_ != nullptr); + return utils_->Convert(buffer, output_buffer); +} + +absl::Status FrameBufferUtils::Orient(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + TFLITE_DCHECK(utils_ != nullptr); + + OrientParams params = + GetOrientParams(buffer.orientation(), output_buffer->orientation()); + if (params.rotation_angle_deg == 0 && !params.flip.has_value()) { + // If no rotation or flip is needed, we will copy the buffer to + // output_buffer. + return utils_->Resize(buffer, output_buffer); + } + + if (params.rotation_angle_deg == 0) { + // Only perform flip operation. + switch (*params.flip) { + case OrientParams::FlipType::kHorizontal: + return utils_->FlipHorizontally(buffer, output_buffer); + case OrientParams::FlipType::kVertical: + return utils_->FlipVertically(buffer, output_buffer); + } + } + + if (!params.flip.has_value()) { + // Only perform rotation operation. + return utils_->Rotate(buffer, params.rotation_angle_deg, output_buffer); + } + + // Perform rotation and flip operations. + // Create a temporary buffer to hold the rotation result. 
+ auto tmp_buffer = absl::make_unique<uint8[]>( + GetBufferByteSize(output_buffer->dimension(), output_buffer->format())); + auto tmp_frame_buffer = FrameBuffer::Create( + GetPlanes(tmp_buffer.get(), output_buffer->dimension(), + output_buffer->format()), + output_buffer->dimension(), buffer.format(), buffer.orientation()); + + RETURN_IF_ERROR(utils_->Rotate(buffer, params.rotation_angle_deg, + tmp_frame_buffer.get())); + if (params.flip == OrientParams::FlipType::kHorizontal) { + return utils_->FlipHorizontally(*tmp_frame_buffer, output_buffer); + } else { + return utils_->FlipVertically(*tmp_frame_buffer, output_buffer); + } +} + +absl::Status FrameBufferUtils::Execute( + const FrameBuffer& buffer, + const std::vector<FrameBufferOperation>& operations, + FrameBuffer* output_buffer) { + // Reference variables to swapping input and output buffers for each command. + FrameBuffer input_frame_buffer = buffer; + FrameBuffer temp_frame_buffer = buffer; + + // Temporary buffers and its size to hold intermediate results. + int buffer1_size = 0; + int buffer2_size = 0; + std::unique_ptr<uint8[]> buffer1; + std::unique_ptr<uint8[]> buffer2; + + for (int i = 0; i < operations.size(); i++) { + const FrameBufferOperation& operation = operations[i]; + + // The first command's input is always passed in `buffer`. Before + // process each command, the input_frame_buffer is pointed at the previous + // command's output buffer. + if (i == 0) { + input_frame_buffer = buffer; + } else { + input_frame_buffer = temp_frame_buffer; + } + + // Calculates the resulting metadata from the command and the input. + FrameBuffer::Dimension new_size = GetSize(input_frame_buffer, operation); + FrameBuffer::Orientation new_orientation = + GetOrientation(input_frame_buffer, operation); + FrameBuffer::Format new_format = GetFormat(input_frame_buffer, operation); + int byte_size = GetBufferByteSize(new_size, new_format); + + // The last command's output buffer is always passed in `output_buffer`. 
+ // For other commands, we create temporary FrameBuffer for processing. + if ((i + 1) == operations.size()) { + temp_frame_buffer = *output_buffer; + // Validate the `output_buffer` metadata mathes with command line chain + // resulting metadata. + if (temp_frame_buffer.format() != new_format || + temp_frame_buffer.orientation() != new_orientation || + temp_frame_buffer.dimension() != new_size) { + return absl::InvalidArgumentError( + "The output metadata does not match pipeline result metadata."); + } + } else { + // Create a temporary buffer to hold intermediate results. For simplicity, + // we only create one continuous memory with no padding for intermediate + // results. + // + // We hold maximum 2 temporary buffers in memory at any given time. + // + // The pipeline is a linear chain. The output buffer from previous command + // becomes the input buffer for the next command. We simply use odd / even + // index to swap between buffers. + std::vector<FrameBuffer::Plane> planes; + if (i % 2 == 0) { + if (buffer1_size < byte_size) { + buffer1_size = byte_size; + buffer1 = absl::make_unique<uint8[]>(byte_size); + } + planes = GetPlanes(buffer1.get(), new_size, new_format); + } else { + if (buffer2_size < byte_size) { + buffer2_size = byte_size; + buffer2 = absl::make_unique<uint8[]>(byte_size); + } + planes = GetPlanes(buffer2.get(), new_size, new_format); + } + if (planes.empty()) { + return absl::InternalError("Failed to construct temporary buffer."); + } + temp_frame_buffer = FrameBuffer(planes, new_size, new_format, + new_orientation, buffer.timestamp()); + } + RETURN_IF_ERROR(Execute(input_frame_buffer, operation, &temp_frame_buffer)); + } + return absl::OkStatus(); +} + +absl::Status FrameBufferUtils::Preprocess( + const FrameBuffer& buffer, absl::optional<BoundingBox> bounding_box, + FrameBuffer* output_buffer) { + std::vector<FrameBufferOperation> frame_buffer_operations; + // Handle cropping and resizing. 
+ bool needs_dimension_swap = + RequireDimensionSwap(buffer.orientation(), output_buffer->orientation()); + // For intermediate steps, we need to use dimensions based on the input + // orientation. + FrameBuffer::Dimension pre_orient_dimension = output_buffer->dimension(); + if (needs_dimension_swap) { + pre_orient_dimension.Swap(); + } + + if (bounding_box.has_value()) { + // Cropping case. + frame_buffer_operations.push_back(CropResizeOperation( + bounding_box.value().origin_x(), bounding_box.value().origin_y(), + FrameBuffer::Dimension{bounding_box.value().width(), + bounding_box.value().height()}, + pre_orient_dimension)); + } else if (pre_orient_dimension != buffer.dimension()) { + // Resizing case. + frame_buffer_operations.push_back( + CropResizeOperation(0, 0, buffer.dimension(), pre_orient_dimension)); + } + + // Handle color space conversion. + if (output_buffer->format() != buffer.format()) { + frame_buffer_operations.push_back( + ConvertOperation(output_buffer->format())); + } + + // Handle orientation conversion. + if (output_buffer->orientation() != buffer.orientation()) { + frame_buffer_operations.push_back( + OrientOperation(output_buffer->orientation())); + } + + // Execute the processing pipeline. + if (frame_buffer_operations.empty()) { + // Using resize to perform copy. + RETURN_IF_ERROR(Resize(buffer, output_buffer)); + } else { + RETURN_IF_ERROR(Execute(buffer, frame_buffer_operations, output_buffer)); + } + return absl::OkStatus(); +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h new file mode 100644 index 00000000..90a7491e --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h @@ -0,0 +1,292 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_H_ + +#include <memory> +#include <vector> + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h" + +namespace tflite { +namespace task { +namespace vision { + +// Returns the minimal buffer size for a plane in bytes based on the given +// format and dimensions. +int GetBufferByteSize(FrameBuffer::Dimension dimension, + FrameBuffer::Format format); + +// Rotates the `from_box` in `from_orientation` to `to_orientation` within an +// image of size `from_dimension`. +BoundingBox OrientBoundingBox(const BoundingBox& from_box, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, + FrameBuffer::Dimension from_dimension); + +// Same as OrientBoundingBox but from normalized coordinates. 
+BoundingBox OrientAndDenormalizeBoundingBox( + float from_left, float from_top, float from_right, float from_bottom, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, + FrameBuffer::Dimension from_dimension); + +// Rotates `(from_x, from_y)` coordinates from an image of dimension +// `from_dimension` and orientation `from_orientation` into `(to_x, to_y)` +// coordinates with orientation `to_orientation`. +void OrientCoordinates(int from_x, int from_y, + FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation, + FrameBuffer::Dimension from_dimension, int* to_x, + int* to_y); + +// Returns whether the conversion from from_orientation to to_orientation +// requires 90 or 270 degrees rotation. +bool RequireDimensionSwap(FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation); + +// Structure to express parameters needed to achieve orientation conversion. +struct OrientParams { + // Counterclockwise rotation angle in degrees. This is expressed as a + // multiple of 90 degrees. + int rotation_angle_deg; + // Flipping operation. It must come after the rotation. + enum class FlipType { kHorizontal, kVertical }; + absl::optional<FlipType> flip; +}; + +// Returns rotation angle and the need for horizontal flipping or vertical +// flipping. +OrientParams GetOrientParams(FrameBuffer::Orientation from_orientation, + FrameBuffer::Orientation to_orientation); + +// The parameters needed to crop / resize. +// +// The coordinate system has its origin at the upper left corner, and +// positive values extend down and to the right from it. +// +// After the operation, the `crop_origin` will become the new origin. +// `crop_width` and `crop_height` defines the desired cropping region. After +// cropping, a resize is performed based on the `resize_width` and +// `resize_height`. 
+// +// To perform just cropping, the `crop_width` and `crop_height` should be the +// same as `resize_width` `and resize_height`. +struct CropResizeOperation { + CropResizeOperation(int crop_origin_x, int crop_origin_y, + FrameBuffer::Dimension crop_dimension, + FrameBuffer::Dimension resize_dimension) + : crop_origin_x(crop_origin_x), + crop_origin_y(crop_origin_y), + crop_dimension(crop_dimension), + resize_dimension(resize_dimension) {} + + int crop_origin_x; + int crop_origin_y; + FrameBuffer::Dimension crop_dimension; + FrameBuffer::Dimension resize_dimension; +}; + +// The parameters needed to convert to the specified format. +struct ConvertOperation { + explicit ConvertOperation(FrameBuffer::Format to_format) + : to_format(to_format) {} + FrameBuffer::Format to_format; +}; + +// The parameters needed to change the orientation. +struct OrientOperation { + explicit OrientOperation(FrameBuffer::Orientation to_orientation) + : to_orientation(to_orientation) {} + FrameBuffer::Orientation to_orientation; +}; + +// A variant of the supported operations on FrameBuffers. Alias for user +// convenience. +using FrameBufferOperation = + absl::variant<CropResizeOperation, ConvertOperation, OrientOperation>; + +// Image processing utility. This utility provides both basic image buffer +// manipulations (e.g. rotation, format conversion, resizing, etc) as well as +// capability for chaining pipeline executions. The actual buffer processing +// engine is configurable to allow optimization based on platforms. +// +// Examples: +// +// // Create an instance of FrameBufferUtils with Halide processing engine. +// std::unique_ptr<FrameBufferUtils> utils = FrameBufferUtils::Create(kHalide); +// +// // Perform single basic operation by each individual call. 
+// std::unique_ptr<FrameBuffer> input = FrameBuffer::Create(...); +// std::unique_ptr<FrameBuffer> output = FrameBuffer::Create(...); +// utils->Orient(*input, output.get()); +// utils->Resize(*input, output.get()); +// +// // Chaining processing operations. +// const std::vector<FrameBufferOperation> operations = { +// ConvertOperation(FrameBuffer::Format::kNV21), +// CropResizeOperation(/*crop_origin_x=*/20, /*crop_origin_y=*/20, +// /*crop_width=*/10, /*crop_height=*/10, +// /*resize_width=*/10, /*resize_height=*/10), +// OrientOperation(FrameBuffer::Orientation::kLeftTop)}; +// utils->Execute(*input, operations, output.get()); +class FrameBufferUtils { + public: + // Counter-clockwise rotation in degree. + enum class RotationDegree { k0 = 0, k90 = 1, k180 = 2, k270 = 3 }; + + // Underlying process engine used for performing operations. + enum class ProcessEngine { + kLibyuv, + }; + + // Factory method FrameBufferUtils instance. The processing engine is + // defined by `engine`. + static std::unique_ptr<FrameBufferUtils> Create(ProcessEngine engine) { + return absl::make_unique<FrameBufferUtils>(engine); + } + + explicit FrameBufferUtils(ProcessEngine engine); + + // Performs cropping operation. + // + // The coordinate system has its origin at the upper left corner, and + // positive values extend down and to the right from it. After cropping, + // (x0, y0) becomes (0, 0). The new width and height are + // (x1 - x0 + 1, y1 - y0 + 1). + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. If the `output_buffer` + // size dimension does not match with crop dimension, then a resize is + // automatically performed. + absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, + FrameBuffer* output_buffer); + + // Performs resizing operation. + // + // The resize dimension is determined based on output_buffer's size metadata. 
+ // + // The output_buffer should have metadata populated and its backing buffer + // should be big enough to store the operation result. + absl::Status Resize(const FrameBuffer& buffer, FrameBuffer* output_buffer); + + // Performs rotation operation. + // + // The rotation is specified in counter-clockwise direction. + // + // The output_buffer should have metadata populated and its backing buffer + // should be big enough to store the operation result. + absl::Status Rotate(const FrameBuffer& buffer, RotationDegree rotation, + FrameBuffer* output_buffer); + + // Performs horizontal flip operation. + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + absl::Status FlipHorizontally(const FrameBuffer& buffer, + FrameBuffer* output_buffer); + + // Performs vertical flip operation. + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + absl::Status FlipVertically(const FrameBuffer& buffer, + FrameBuffer* output_buffer); + + // Performs buffer format conversion. + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + absl::Status Convert(const FrameBuffer& buffer, FrameBuffer* output_buffer); + + // Performs buffer orientation conversion. Depends on the orientations, this + // method may perform rotation and optional flipping operations. + // + // If `buffer` and `output_buffer` has the same orientation, then a copy + // operation will performed. + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + absl::Status Orient(const FrameBuffer& buffer, FrameBuffer* output_buffer); + + // Performs the image processing operations specified, in that order. 
+ // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + absl::Status Execute(const FrameBuffer& buffer, + const std::vector<FrameBufferOperation>& operations, + FrameBuffer* output_buffer); + + // Performs a chain of operations to convert `buffer` to desired metadata + // (width, height, format, orientation) defined by `output_buffer` and + // optional cropping (`bounding_box`). + // + // Internally, a chain of operations is constructed. For performance + // optimization, operations are performed in the following order: crop, + // resize, convert color space format, and rotate. + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. Insufficient backing + // buffer size may cause garbage result or crash. Use `GetBufferByteSize` to + // calculate the minimal buffer size. + // + // If the `buffer` is already in desired format, then an extra copy will be + // performed. + // + // The input param `bounding_box` is defined in the `buffer` coordinate space. + absl::Status Preprocess(const FrameBuffer& buffer, + absl::optional<BoundingBox> bounding_box, + FrameBuffer* output_buffer); + + private: + // Returns the new FrameBuffer size after the operation is applied. + FrameBuffer::Dimension GetSize(const FrameBuffer& buffer, + const FrameBufferOperation& operation); + + // Returns the new FrameBuffer orientation after command is processed. + FrameBuffer::Orientation GetOrientation( + const FrameBuffer& buffer, const FrameBufferOperation& operation); + + // Returns the new FrameBuffer format after command is processed. + FrameBuffer::Format GetFormat(const FrameBuffer& buffer, + const FrameBufferOperation& operation); + + // Returns Plane struct based on one dimension buffer and its metadata. If + // an error occurred, it will return an empty vector. 
+ std::vector<FrameBuffer::Plane> GetPlanes(const uint8* buffer, + FrameBuffer::Dimension dimension, + FrameBuffer::Format format); + + // Executes command with params. + absl::Status Execute(const FrameBuffer& buffer, + const FrameBufferOperation& operation, + FrameBuffer* output_buffer); + + // Execution engine conforms to FrameBufferUtilsInterface. + std::unique_ptr<FrameBufferUtilsInterface> utils_; +}; + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_H_ diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h new file mode 100644 index 00000000..502e998d --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_INTERFACE_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_INTERFACE_H_ + +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" + +namespace tflite { +namespace task { +namespace vision { + +// Interface for the FrameBuffer image processing library. 
+class FrameBufferUtilsInterface { + public: + virtual ~FrameBufferUtilsInterface() = default; + + // Crops `buffer` to the specified points. + // + // The coordinate system has its origin at the upper left corner, and + // positive values extend down and to the right from it. After cropping, + // the top left point becomes (0, 0). The new width and height are + // (x1 - x0 + 1, y1 - y0 + 1). + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + virtual absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, + int y1, FrameBuffer* output_buffer) = 0; + + // Resizes `buffer` to the size of the given `output_buffer`. + // + // The resize dimension is determined based on the size of `output_buffer`. + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + virtual absl::Status Resize(const FrameBuffer& buffer, + FrameBuffer* output_buffer) = 0; + + // Rotates `buffer` counter-clockwise by the given `angle_deg` (in degrees). + // + // When rotating by 90 degrees, the top-right corner of `buffer` becomes + // the top-left corner of `output_buffer`. The given angle must be a multiple + // of 90 degrees. + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + virtual absl::Status Rotate(const FrameBuffer& buffer, int angle_deg, + FrameBuffer* output_buffer) = 0; + + // Flips `buffer` horizontally. + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + virtual absl::Status FlipHorizontally(const FrameBuffer& buffer, + FrameBuffer* output_buffer) = 0; + + // Flips `buffer` vertically. 
+ // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + virtual absl::Status FlipVertically(const FrameBuffer& buffer, + FrameBuffer* output_buffer) = 0; + + // Converts `buffer`'s format to the format of the given `output_buffer`. + // + // The `output_buffer` should have metadata populated and its backing buffer + // should be big enough to store the operation result. + virtual absl::Status Convert(const FrameBuffer& buffer, + FrameBuffer* output_buffer) = 0; +}; +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_FRAME_BUFFER_UTILS_INTERFACE_H_ diff --git a/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc new file mode 100644 index 00000000..51e72fa1 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc @@ -0,0 +1,254 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h" + +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" + +namespace tflite { +namespace task { +namespace vision { +namespace { + +using ::absl::StatusCode; +using ::tflite::ColorSpaceType_RGB; +using ::tflite::ContentProperties; +using ::tflite::ContentProperties_ImageProperties; +using ::tflite::EnumNameContentProperties; +using ::tflite::ImageProperties; +using ::tflite::TensorMetadata; +using ::tflite::metadata::ModelMetadataExtractor; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; +using ::tflite::task::core::TfLiteEngine; + +StatusOr<const TensorMetadata*> GetInputTensorMetadataIfAny( + const ModelMetadataExtractor& metadata_extractor) { + if (metadata_extractor.GetModelMetadata() == nullptr || + metadata_extractor.GetModelMetadata()->subgraph_metadata() == nullptr) { + // Some models have no metadata at all (or very partial), so exit early. + return nullptr; + } else if (metadata_extractor.GetInputTensorCount() != 1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Models are assumed to have a single input TensorMetadata.", + TfLiteSupportStatus::kInvalidNumInputTensorsError); + } + + const TensorMetadata* metadata = metadata_extractor.GetInputTensorMetadata(0); + + if (metadata == nullptr) { + // Should never happen. 
+ return CreateStatusWithPayload(StatusCode::kInternal, + "Input TensorMetadata is null."); + } + + return metadata; +} + +StatusOr<const ImageProperties*> GetImagePropertiesIfAny( + const TensorMetadata& tensor_metadata) { + if (tensor_metadata.content() == nullptr || + tensor_metadata.content()->content_properties() == nullptr) { + return nullptr; + } + + ContentProperties type = tensor_metadata.content()->content_properties_type(); + + if (type != ContentProperties_ImageProperties) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat( + "Expected ImageProperties for tensor ", + tensor_metadata.name() ? tensor_metadata.name()->str() : "#0", + ", got ", EnumNameContentProperties(type), "."), + TfLiteSupportStatus::kMetadataInvalidContentPropertiesError); + } + + return tensor_metadata.content()->content_properties_as_ImageProperties(); +} + +StatusOr<absl::optional<NormalizationOptions>> GetNormalizationOptionsIfAny( + const TensorMetadata& tensor_metadata) { + ASSIGN_OR_RETURN( + const tflite::ProcessUnit* normalization_process_unit, + ModelMetadataExtractor::FindFirstProcessUnit( + tensor_metadata, tflite::ProcessUnitOptions_NormalizationOptions)); + if (normalization_process_unit == nullptr) { + return {absl::nullopt}; + } + const tflite::NormalizationOptions* tf_normalization_options = + normalization_process_unit->options_as_NormalizationOptions(); + const auto mean_values = tf_normalization_options->mean(); + const auto std_values = tf_normalization_options->std(); + if (mean_values->size() != std_values->size()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("NormalizationOptions: expected mean and std of same " + "dimension, got ", + mean_values->size(), " and ", std_values->size(), "."), + TfLiteSupportStatus::kMetadataInvalidProcessUnitsError); + } + absl::optional<NormalizationOptions> normalization_options; + if (mean_values->size() == 1) { + normalization_options = NormalizationOptions{ 
+ .mean_values = {mean_values->Get(0), mean_values->Get(0), + mean_values->Get(0)}, + .std_values = {std_values->Get(0), std_values->Get(0), + std_values->Get(0)}, + .num_values = 1}; + } else if (mean_values->size() == 3) { + normalization_options = NormalizationOptions{ + .mean_values = {mean_values->Get(0), mean_values->Get(1), + mean_values->Get(2)}, + .std_values = {std_values->Get(0), std_values->Get(1), + std_values->Get(2)}, + .num_values = 3}; + } else { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("NormalizationOptions: only 1 or 3 mean and std " + "values are supported, got ", + mean_values->size(), "."), + TfLiteSupportStatus::kMetadataInvalidProcessUnitsError); + } + return normalization_options; +} + +} // namespace + +StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs( + const TfLiteEngine::Interpreter& interpreter, + const tflite::metadata::ModelMetadataExtractor& metadata_extractor) { + ASSIGN_OR_RETURN(const TensorMetadata* metadata, + GetInputTensorMetadataIfAny(metadata_extractor)); + + const ImageProperties* props = nullptr; + absl::optional<NormalizationOptions> normalization_options; + if (metadata != nullptr) { + ASSIGN_OR_RETURN(props, GetImagePropertiesIfAny(*metadata)); + ASSIGN_OR_RETURN(normalization_options, + GetNormalizationOptionsIfAny(*metadata)); + } + + if (TfLiteEngine::InputCount(&interpreter) != 1) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Models are assumed to have a single input.", + TfLiteSupportStatus::kInvalidNumInputTensorsError); + } + + // Input-related specifications. 
+ const TfLiteTensor* input_tensor = TfLiteEngine::GetInput(&interpreter, 0); + if (input_tensor->dims->size != 4) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Only 4D tensors in BHWD layout are supported.", + TfLiteSupportStatus::kInvalidInputTensorDimensionsError); + } + static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteFloat32}; + TfLiteType input_type = input_tensor->type; + if (!absl::c_linear_search(valid_types, input_type)) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat( + "Type mismatch for input tensor ", input_tensor->name, + ". Requested one of these types: kTfLiteUint8/kTfLiteFloat32, got ", + TfLiteTypeGetName(input_type), "."), + TfLiteSupportStatus::kInvalidInputTensorTypeError); + } + + // The expected layout is BHWD, i.e. batch x height x width x color + // See https://www.tensorflow.org/guide/tensors + const int batch = input_tensor->dims->data[0]; + const int height = input_tensor->dims->data[1]; + const int width = input_tensor->dims->data[2]; + const int depth = input_tensor->dims->data[3]; + + if (props != nullptr && props->color_space() != ColorSpaceType_RGB) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Only RGB color space is supported for now.", + TfLiteSupportStatus::kInvalidArgumentError); + } + if (batch != 1 || depth != 3) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("The input tensor should have dimensions 1 x height x " + "width x 3. Got ", + batch, " x ", height, " x ", width, " x ", depth, "."), + TfLiteSupportStatus::kInvalidInputTensorDimensionsError); + } + int bytes_size = input_tensor->bytes; + size_t byte_depth = + input_type == kTfLiteFloat32 ? sizeof(float) : sizeof(uint8); + + // Sanity checks. 
+ if (input_type == kTfLiteFloat32) { + if (!normalization_options.has_value()) { + return CreateStatusWithPayload( + absl::StatusCode::kNotFound, + "Input tensor has type kTfLiteFloat32: it requires specifying " + "NormalizationOptions metadata to preprocess input images.", + TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError); + } else if (bytes_size / sizeof(float) % + normalization_options.value().num_values != + 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "The number of elements in the input tensor must be a multiple of " + "the number of normalization parameters.", + TfLiteSupportStatus::kInvalidArgumentError); + } + } + if (width <= 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, "The input width should be positive.", + TfLiteSupportStatus::kInvalidInputTensorDimensionsError); + } + if (height <= 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, "The input height should be positive.", + TfLiteSupportStatus::kInvalidInputTensorDimensionsError); + } + if (bytes_size != height * width * depth * byte_depth) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "The input size in bytes does not correspond to the expected number of " + "pixels.", + TfLiteSupportStatus::kInvalidInputTensorSizeError); + } + + // Note: in the future, additional checks against `props->default_size()` + // might be added. Also, verify that NormalizationOptions, if any, do specify + // a single value when color space is grayscale. 
+ + ImageTensorSpecs result; + result.image_width = width; + result.image_height = height; + result.color_space = ColorSpaceType_RGB; + result.tensor_type = input_type; + result.normalization_options = normalization_options; + + return result; +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h new file mode 100644 index 00000000..536eed4d --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h @@ -0,0 +1,93 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_ + +#include <array> + +#include "absl/types/optional.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/tflite_engine.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace vision { + +// Parameters used for input image normalization when input tensor has +// kTfLiteFloat32 type. 
//
// Exactly 1 or 3 values are expected for `mean_values` and `std_values`. In
// case 1 value only is specified, it is used for all channels. E.g. for a RGB
// image, the normalization is done as follows:
//
//   (R - mean_values[0]) / std_values[0]
//   (G - mean_values[1]) / std_values[1]
//   (B - mean_values[2]) / std_values[2]
//
// `num_values` keeps track of how many values have been provided, which should
// be 1 or 3 (see above). In particular, single-channel grayscale images expect
// only 1 value.
struct NormalizationOptions {
  // Per-channel mean to subtract; only the first `num_values` entries are
  // meaningful.
  std::array<float, 3> mean_values;
  // Per-channel standard deviation to divide by; only the first `num_values`
  // entries are meaningful.
  std::array<float, 3> std_values;
  // Number of valid entries in `mean_values` / `std_values` (1 or 3).
  int num_values;
};

// Parameters related to the expected tensor specifications when the tensor
// represents an image.
//
// E.g. input tensor specifications expected by the model at Invoke() time. In
// such a case, and before running inference with the TF Lite interpreter, the
// caller must use these values and perform image preprocessing and/or
// normalization so as to fill the actual input tensor appropriately.
struct ImageTensorSpecs {
  // Expected image dimensions, e.g. image_width=224, image_height=224.
  int image_width;
  int image_height;
  // Expected color space, e.g. color_space=RGB.
  tflite::ColorSpaceType color_space;
  // Expected input tensor type, e.g. if tensor_type=kTfLiteFloat32 the caller
  // should usually perform some normalization to convert the uint8 pixels into
  // floats (see NormalizationOptions in TF Lite Metadata for more details).
  TfLiteType tensor_type;
  // Optional normalization parameters read from TF Lite Metadata. Those are
  // mandatory when tensor_type=kTfLiteFloat32 in order to convert the input
  // image data into the expected range of floating point values, an error is
  // returned otherwise (see sanity checks below). They should be ignored for
  // other tensor input types, e.g. kTfLiteUInt8.
  absl::optional<NormalizationOptions> normalization_options;
};

// Performs sanity checks on the expected input tensor including consistency
// checks against model metadata, if any. For now, a single RGB input with BHWD
// layout, where B = 1 and D = 3, is expected. Returns the corresponding input
// specifications if they pass, or an error otherwise (too many input tensors,
// etc).
// Note: both interpreter and metadata extractor *must* be successfully
// initialized before calling this function by means of (respectively):
// - `tflite::InterpreterBuilder`,
// - `tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer`.
tflite::support::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
    const tflite::task::core::TfLiteEngine::Interpreter& interpreter,
    const tflite::metadata::ModelMetadataExtractor& metadata_extractor);

}  // namespace vision
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
diff --git a/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc b/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
new file mode 100644
index 00000000..beb58eb4
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.cc
@@ -0,0 +1,1499 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h" + +#include <stdint.h> + +#include <memory> +#include <string> + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "include/libyuv.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" + +namespace tflite { +namespace task { +namespace vision { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::TfLiteSupportStatus; + +namespace { + +// Converts NV12 `buffer` to the `output_buffer` of the target color space. +// Supported output format includes RGB24 and YV21. +absl::Status ConvertFromNv12(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + switch (output_buffer->format()) { + case FrameBuffer::Format::kRGB: { + // The RAW format of Libyuv represents the 8-bit interleaved RGB format in + // the big endian style with R being the first byte in memory. + int ret = libyuv::NV12ToRAW( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer, + yuv_data.uv_row_stride, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv NV12ToRAW operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kRGBA: { + // The libyuv ABGR format is interleaved RGBA format in memory. 
+ int ret = libyuv::NV12ToABGR( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer, + yuv_data.uv_row_stride, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv NV12ToABGR operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + int ret = libyuv::NV12ToI420( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer, + yuv_data.uv_row_stride, const_cast<uint8_t*>(output_data.y_buffer), + output_data.y_row_stride, const_cast<uint8_t*>(output_data.u_buffer), + output_data.uv_row_stride, const_cast<uint8_t*>(output_data.v_buffer), + output_data.uv_row_stride, output_buffer->dimension().width, + output_buffer->dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv NV12ToI420 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kNV21: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride, + const_cast<uint8*>(output_data.y_buffer), + output_data.y_row_stride, buffer.dimension().width, + buffer.dimension().height); + ASSIGN_OR_RETURN( + const FrameBuffer::Dimension uv_plane_dimension, + GetUvPlaneDimension(buffer.dimension(), buffer.format())); + libyuv::SwapUVPlane(yuv_data.u_buffer, yuv_data.uv_row_stride, + const_cast<uint8*>(output_data.v_buffer), + output_data.uv_row_stride, uv_plane_dimension.width, + uv_plane_dimension.height); + break; + } + case FrameBuffer::Format::kGRAY: { + 
libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + output_buffer->dimension().width, + output_buffer->dimension().height); + break; + } + default: + return absl::InternalError(absl::StrFormat("Format %i is not supported.", + output_buffer->format())); + } + return absl::OkStatus(); +} + +// Converts NV21 `buffer` into the `output_buffer` of the target color space. +// Supported output format includes RGB24 and YV21. +absl::Status ConvertFromNv21(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + switch (output_buffer->format()) { + case FrameBuffer::Format::kRGB: { + // The RAW format of Libyuv represents the 8-bit interleaved RGB format in + // the big endian style with R being the first byte in memory. + int ret = libyuv::NV21ToRAW( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.v_buffer, + yuv_data.uv_pixel_stride, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv NV21ToRAW operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kRGBA: { + // The libyuv ABGR format is interleaved RGBA format in memory. 
+ int ret = libyuv::NV21ToABGR( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.v_buffer, + yuv_data.uv_pixel_stride, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv NV21ToABGR operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + int ret = libyuv::NV21ToI420( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.v_buffer, + yuv_data.uv_row_stride, const_cast<uint8_t*>(output_data.y_buffer), + output_data.y_row_stride, const_cast<uint8_t*>(output_data.u_buffer), + output_data.uv_row_stride, const_cast<uint8_t*>(output_data.v_buffer), + output_data.uv_row_stride, output_buffer->dimension().width, + output_buffer->dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv NV21ToI420 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kNV12: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride, + const_cast<uint8*>(output_data.y_buffer), + output_data.y_row_stride, buffer.dimension().width, + buffer.dimension().height); + ASSIGN_OR_RETURN( + const FrameBuffer::Dimension uv_plane_dimension, + GetUvPlaneDimension(buffer.dimension(), buffer.format())); + libyuv::SwapUVPlane(yuv_data.v_buffer, yuv_data.uv_row_stride, + const_cast<uint8*>(output_data.u_buffer), + output_data.uv_row_stride, uv_plane_dimension.width, + uv_plane_dimension.height); + break; + } + case FrameBuffer::Format::kGRAY: { + 
libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + output_buffer->dimension().width, + output_buffer->dimension().height); + break; + } + default: + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Format %i is not supported.", + output_buffer->format()), + TfLiteSupportStatus::kImageProcessingError); + } + return absl::OkStatus(); +} + +// Converts YV12/YV21 `buffer` to the `output_buffer` of the target color space. +// Supported output format includes RGB24, NV12, and NV21. +absl::Status ConvertFromYv(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + switch (output_buffer->format()) { + case FrameBuffer::Format::kRGB: { + // The RAW format of Libyuv represents the 8-bit interleaved RGB format in + // the big endian style with R being the first byte in memory. + int ret = libyuv::I420ToRAW( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer, + yuv_data.uv_row_stride, yuv_data.v_buffer, yuv_data.uv_row_stride, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv I420ToRAW operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kRGBA: { + // The libyuv ABGR format is interleaved RGBA format in memory. 
+ int ret = libyuv::I420ToABGR( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer, + yuv_data.uv_row_stride, yuv_data.v_buffer, yuv_data.uv_row_stride, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv I420ToABGR operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kNV12: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + int ret = libyuv::I420ToNV12( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer, + yuv_data.uv_row_stride, yuv_data.v_buffer, yuv_data.uv_row_stride, + const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride, + const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride, + output_buffer->dimension().width, output_buffer->dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv I420ToNV12 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kNV21: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + int ret = libyuv::I420ToNV21( + yuv_data.y_buffer, yuv_data.y_row_stride, yuv_data.u_buffer, + yuv_data.uv_row_stride, yuv_data.v_buffer, yuv_data.uv_row_stride, + const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride, + const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride, + output_buffer->dimension().width, output_buffer->dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv I420ToNV21 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kGRAY: { + 
libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + output_buffer->dimension().width, + output_buffer->dimension().height); + break; + } + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_yuv_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + ASSIGN_OR_RETURN( + const FrameBuffer::Dimension uv_plane_dimension, + GetUvPlaneDimension(buffer.dimension(), buffer.format())); + libyuv::CopyPlane(yuv_data.y_buffer, yuv_data.y_row_stride, + const_cast<uint8*>(output_yuv_data.y_buffer), + output_yuv_data.y_row_stride, buffer.dimension().width, + buffer.dimension().height); + libyuv::CopyPlane(yuv_data.u_buffer, yuv_data.uv_row_stride, + const_cast<uint8*>(output_yuv_data.u_buffer), + output_yuv_data.uv_row_stride, uv_plane_dimension.width, + uv_plane_dimension.height); + libyuv::CopyPlane(yuv_data.v_buffer, yuv_data.uv_row_stride, + const_cast<uint8*>(output_yuv_data.v_buffer), + output_yuv_data.uv_row_stride, uv_plane_dimension.width, + uv_plane_dimension.height); + break; + } + default: + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Format %i is not supported.", + output_buffer->format()), + TfLiteSupportStatus::kImageProcessingError); + } + return absl::OkStatus(); +} + +// Resizes YV12/YV21 `buffer` to the target `output_buffer`. +absl::Status ResizeYv(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + // TODO(b/151217096): Choose the optimal image resizing filter to optimize + // the model inference performance. 
+ int ret = libyuv::I420Scale( + input_data.y_buffer, input_data.y_row_stride, input_data.u_buffer, + input_data.uv_row_stride, input_data.v_buffer, input_data.uv_row_stride, + buffer.dimension().width, buffer.dimension().height, + const_cast<uint8_t*>(output_data.y_buffer), output_data.y_row_stride, + const_cast<uint8_t*>(output_data.u_buffer), output_data.uv_row_stride, + const_cast<uint8_t*>(output_data.v_buffer), output_data.uv_row_stride, + output_buffer->dimension().width, output_buffer->dimension().height, + libyuv::FilterMode::kFilterBilinear); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv I420Scale operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); +} + +// Resizes NV12/NV21 `buffer` to the target `output_buffer`. +absl::Status ResizeNv(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + const int buffer_size = + GetFrameBufferByteSize(buffer.dimension(), FrameBuffer::Format::kYV21); + auto yuv_raw_buffer = absl::make_unique<uint8[]>(buffer_size); + ASSIGN_OR_RETURN( + std::unique_ptr<FrameBuffer> yuv_buffer, + CreateFromRawBuffer(yuv_raw_buffer.get(), buffer.dimension(), + FrameBuffer::Format::kYV21, buffer.orientation())); + // TODO(b/151375918): Current implementation is a workaround by converting + // input NV12/NV21 buffer to the YV12 formats, resizing the YV12 buffer, and + // converting the resized YV12 buffer back to the target format. Consider + // optimizes this by adding the support of NV12/NV21 resizing in Libyuv. 
+ if (buffer.format() == FrameBuffer::Format::kNV12) { + RETURN_IF_ERROR(ConvertFromNv12(buffer, yuv_buffer.get())); + } else if (buffer.format() == FrameBuffer::Format::kNV21) { + RETURN_IF_ERROR(ConvertFromNv21(buffer, yuv_buffer.get())); + } else { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Format %i is not supported.", buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + + const int resized_buffer_size = GetFrameBufferByteSize( + output_buffer->dimension(), FrameBuffer::Format::kYV12); + auto resized_yuv_raw_buffer = absl::make_unique<uint8[]>(resized_buffer_size); + ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> resized_yuv_buffer, + CreateFromRawBuffer(resized_yuv_raw_buffer.get(), + output_buffer->dimension(), + FrameBuffer::Format::kYV12, + output_buffer->orientation())); + RETURN_IF_ERROR(ResizeYv(*yuv_buffer, resized_yuv_buffer.get())); + + RETURN_IF_ERROR(ConvertFromYv(*resized_yuv_buffer, output_buffer)); + return absl::OkStatus(); +} + +// Converts `buffer` to libyuv ARGB format and stores the conversion result +// in `dest_argb`. 
+absl::Status ConvertRgbToArgb(const FrameBuffer& buffer, uint8* dest_argb, + int dest_stride_argb) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); + if (buffer.format() != FrameBuffer::Format::kRGB) { + return CreateStatusWithPayload(StatusCode::kInternal, + "RGB input format is expected.", + TfLiteSupportStatus::kImageProcessingError); + } + + if (dest_argb == nullptr || dest_stride_argb <= 0) { + return CreateStatusWithPayload( + StatusCode::kInternal, + "Invalid destination arguments for ConvertRgbToArgb.", + TfLiteSupportStatus::kImageProcessingError); + } + + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + int ret = libyuv::RGB24ToARGB( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + dest_argb, dest_stride_argb, buffer.dimension().width, + buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv RGB24ToARGB operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); +} + +// Converts `src_argb` in libyuv ARGB format to FrameBuffer::kRGB format and +// stores the conversion result in `output_buffer`. 
+absl::Status ConvertArgbToRgb(uint8* src_argb, int src_stride_argb, + FrameBuffer* output_buffer) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer)); + if (output_buffer->format() != FrameBuffer::Format::kRGB) { + return absl::InternalError("RGB input format is expected."); + } + + if (src_argb == nullptr || src_stride_argb <= 0) { + return CreateStatusWithPayload( + StatusCode::kInternal, "Invalid source arguments for ConvertArgbToRgb.", + TfLiteSupportStatus::kImageProcessingError); + } + + if (output_buffer->plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + output_buffer->format()), + TfLiteSupportStatus::kImageProcessingError); + } + int ret = libyuv::ARGBToRGB24( + src_argb, src_stride_argb, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + output_buffer->dimension().width, output_buffer->dimension().height); + + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ARGBToRGB24 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); +} + +// Converts `buffer` in FrameBuffer::kRGBA format to libyuv ARGB (BGRA in +// memory) format and stores the conversion result in `dest_argb`. 
+absl::Status ConvertRgbaToArgb(const FrameBuffer& buffer, uint8* dest_argb, + int dest_stride_argb) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); + if (buffer.format() != FrameBuffer::Format::kRGBA) { + return CreateStatusWithPayload( + StatusCode::kInternal, "RGBA input format is expected.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + + if (dest_argb == nullptr || dest_stride_argb <= 0) { + return CreateStatusWithPayload( + StatusCode::kInternal, + "Invalid source arguments for ConvertRgbaToArgb.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + + int ret = libyuv::ABGRToARGB( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + dest_argb, dest_stride_argb, buffer.dimension().width, + buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kInternal, "Libyuv ABGRToARGB operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); +} + +// Converts kRGB `buffer` to the `output_buffer` of the target color space. 
+absl::Status ConvertFromRgb(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + if (output_buffer->format() == FrameBuffer::Format::kGRAY) { + int ret = libyuv::RAWToJ400( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kInternal, "Libyuv RAWToJ400 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); + } else if (output_buffer->format() == FrameBuffer::Format::kYV12 || + output_buffer->format() == FrameBuffer::Format::kYV21 || + output_buffer->format() == FrameBuffer::Format::kNV12 || + output_buffer->format() == FrameBuffer::Format::kNV21) { + // libyuv does not support conversion directly from kRGB to kNV12 / kNV21. + // For kNV12 / kNV21, the implementation converts the kRGB to I420, + // then converts I420 to kNV12 / kNV21. + // TODO(b/153000936): use libyuv::RawToNV12 / libyuv::RawToNV21 when they + // are ready. 
+ FrameBuffer::YuvData yuv_data; + std::unique_ptr<uint8[]> tmp_yuv_buffer; + std::unique_ptr<FrameBuffer> yuv_frame_buffer; + if (output_buffer->format() == FrameBuffer::Format::kNV12 || + output_buffer->format() == FrameBuffer::Format::kNV21) { + tmp_yuv_buffer = absl::make_unique<uint8[]>( + GetFrameBufferByteSize(buffer.dimension(), output_buffer->format())); + ASSIGN_OR_RETURN( + yuv_frame_buffer, + CreateFromRawBuffer(tmp_yuv_buffer.get(), buffer.dimension(), + FrameBuffer::Format::kYV21, + output_buffer->orientation())); + ASSIGN_OR_RETURN( + yuv_data, FrameBuffer::GetYuvDataFromFrameBuffer(*yuv_frame_buffer)); + } else { + ASSIGN_OR_RETURN(yuv_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + } + int ret = libyuv::RAWToI420( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(yuv_data.y_buffer), yuv_data.y_row_stride, + const_cast<uint8*>(yuv_data.u_buffer), yuv_data.uv_row_stride, + const_cast<uint8*>(yuv_data.v_buffer), yuv_data.uv_row_stride, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kInternal, "Libyuv RAWToI420 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + if (output_buffer->format() == FrameBuffer::Format::kNV12 || + output_buffer->format() == FrameBuffer::Format::kNV21) { + return ConvertFromYv(*yuv_frame_buffer, output_buffer); + } + return absl::OkStatus(); + } + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Format %i is not supported.", output_buffer->format()), + TfLiteSupportStatus::kImageProcessingError); +} + +// Converts kRGBA `buffer` to the `output_buffer` of the target color space. +absl::Status ConvertFromRgba(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + switch (output_buffer->format()) { + case FrameBuffer::Format::kGRAY: { + // libyuv does not support convert kRGBA (ABGR) foramat. 
In this method, + // the implementation converts kRGBA format to ARGB and use ARGB buffer + // for conversion. + // TODO(b/141181395): Use libyuv::ABGRToJ400 when it is ready. + + // Convert kRGBA to ARGB + int argb_buffer_size = GetFrameBufferByteSize(buffer.dimension(), + FrameBuffer::Format::kRGBA); + auto argb_buffer = absl::make_unique<uint8[]>(argb_buffer_size); + const int argb_row_bytes = buffer.dimension().width * kRgbaPixelBytes; + RETURN_IF_ERROR( + ConvertRgbaToArgb(buffer, argb_buffer.get(), argb_row_bytes)); + + // Convert ARGB to kGRAY + int ret = libyuv::ARGBToJ400( + argb_buffer.get(), argb_row_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ARGBToJ400 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kNV12: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + int ret = libyuv::ABGRToNV12( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride, + const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ABGRToNV12 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kNV21: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + int ret = libyuv::ABGRToNV21( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride, + const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride, + 
buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ABGRToNV21 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: { + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + int ret = libyuv::ABGRToI420( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride, + const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride, + const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ABGRToI420 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + case FrameBuffer::Format::kRGB: { + // ARGB is BGRA in memory and RGB24 is BGR in memory. The removal of the + // alpha channel will not impact the RGB ordering. + int ret = libyuv::ARGBToRGB24( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ABGRToRGB24 operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + break; + } + default: + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Convert Rgba to format %i is not supported.", + output_buffer->format()), + TfLiteSupportStatus::kImageProcessingError); + } + return absl::OkStatus(); +} + +// Returns libyuv rotation based on counter-clockwise angle_deg. 
+libyuv::RotationMode GetLibyuvRotationMode(int angle_deg) { + switch (angle_deg) { + case 90: + return libyuv::kRotate270; + case 270: + return libyuv::kRotate90; + case 180: + return libyuv::kRotate180; + default: + return libyuv::kRotate0; + } +} + +absl::Status RotateRgba(const FrameBuffer& buffer, int angle_deg, + FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + + // libyuv::ARGBRotate assumes RGBA buffer is in the interleaved format. + int ret = libyuv::ARGBRotate( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, buffer.dimension().width, + buffer.dimension().height, GetLibyuvRotationMode(angle_deg % 360)); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ARGBRotate operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); +} + +absl::Status RotateRgb(const FrameBuffer& buffer, int angle_deg, + FrameBuffer* output_buffer) { + // libyuv does not support rotate kRGB (RGB24) foramat. In this method, the + // implementation converts kRGB format to ARGB and use ARGB buffer for + // rotation. The result is then convert back to RGB. 
+ + // Convert RGB to ARGB + int argb_buffer_size = + GetFrameBufferByteSize(buffer.dimension(), FrameBuffer::Format::kRGBA); + auto argb_buffer = absl::make_unique<uint8[]>(argb_buffer_size); + const int argb_row_bytes = buffer.dimension().width * kRgbaPixelBytes; + RETURN_IF_ERROR(ConvertRgbToArgb(buffer, argb_buffer.get(), argb_row_bytes)); + + // Rotate ARGB + auto argb_rotated_buffer = absl::make_unique<uint8[]>(argb_buffer_size); + int rotated_row_bytes = output_buffer->dimension().width * kRgbaPixelBytes; + // TODO(b/151954340): Optimize the current implementation by utilizing + // ARGBMirror for 180 degree rotation. + int ret = libyuv::ARGBRotate( + argb_buffer.get(), argb_row_bytes, argb_rotated_buffer.get(), + rotated_row_bytes, buffer.dimension().width, buffer.dimension().height, + GetLibyuvRotationMode(angle_deg % 360)); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ARGBRotate operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + + // Convert ARGB to RGB + return ConvertArgbToRgb(argb_rotated_buffer.get(), rotated_row_bytes, + output_buffer); +} + +absl::Status RotateGray(const FrameBuffer& buffer, int angle_deg, + FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + int ret = libyuv::RotatePlane( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, buffer.dimension().width, + buffer.dimension().height, GetLibyuvRotationMode(angle_deg % 360)); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv RotatePlane operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); +} + +// Rotates YV12/YV21 
frame buffer. +absl::Status RotateYv(const FrameBuffer& buffer, int angle_deg, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + int ret = libyuv::I420Rotate( + input_data.y_buffer, input_data.y_row_stride, input_data.u_buffer, + input_data.uv_row_stride, input_data.v_buffer, input_data.uv_row_stride, + const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride, + const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride, + const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride, + buffer.dimension().width, buffer.dimension().height, + GetLibyuvRotationMode(angle_deg)); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv I420Rotate operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); +} + +// Rotates NV12/NV21 frame buffer. +// TODO(b/152097364): Refactor NV12/NV21 rotation after libyuv explicitly +// support that. 
absl::Status RotateNv(const FrameBuffer& buffer, int angle_deg,
                      FrameBuffer* output_buffer) {
  // Only the NV family is handled here; anything else is a programming error
  // in the caller/dispatcher.
  if (buffer.format() != FrameBuffer::Format::kNV12 &&
      buffer.format() != FrameBuffer::Format::kNV21) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "kNV12 or kNV21 input formats are expected.",
                                   TfLiteSupportStatus::kImageProcessingError);
  }
  ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data,
                   FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
  ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data,
                   FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer));
  // Scratch planar (kYV21) frame that receives the rotated U/V planes before
  // they are re-interleaved into the output buffer below.
  const int rotated_buffer_size = GetFrameBufferByteSize(
      output_buffer->dimension(), FrameBuffer::Format::kYV21);
  auto rotated_yuv_raw_buffer = absl::make_unique<uint8[]>(rotated_buffer_size);
  ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> rotated_yuv_buffer,
                   CreateFromRawBuffer(
                       rotated_yuv_raw_buffer.get(), output_buffer->dimension(),
                       /*target_format=*/FrameBuffer::Format::kYV21,
                       output_buffer->orientation()));
  ASSIGN_OR_RETURN(FrameBuffer::YuvData rotated_yuv_data,
                   FrameBuffer::GetYuvDataFromFrameBuffer(*rotated_yuv_buffer));
  // Get the first chroma plane and use it as the u plane. This is a workaround
  // for optimizing NV21 rotation. For NV12, the implementation is logical
  // correct. For NV21, use v plane as u plane will make the UV planes swapped
  // in the intermediate rotated I420 frame. The output buffer is finally built
  // by merging the swapped UV planes which produces V first interleaved UV
  // buffer.
  const uint8* chroma_buffer = buffer.format() == FrameBuffer::Format::kNV12
                                   ? input_data.u_buffer
                                   : input_data.v_buffer;
  // Rotate the Y plane and store into the Y plane in `output_buffer`. Rotate
  // the interleaved UV plane and store into the interleaved UV plane in
  // `rotated_yuv_buffer`.
  int ret = libyuv::NV12ToI420Rotate(
      input_data.y_buffer, input_data.y_row_stride, chroma_buffer,
      input_data.uv_row_stride, const_cast<uint8*>(output_data.y_buffer),
      output_data.y_row_stride, const_cast<uint8*>(rotated_yuv_data.u_buffer),
      rotated_yuv_data.uv_row_stride,
      const_cast<uint8*>(rotated_yuv_data.v_buffer),
      rotated_yuv_data.uv_row_stride, buffer.dimension().width,
      buffer.dimension().height, GetLibyuvRotationMode(angle_deg % 360));
  if (ret != 0) {
    return CreateStatusWithPayload(
        StatusCode::kUnknown, "Libyuv Nv12ToI420Rotate operation failed.",
        TfLiteSupportStatus::kImageProcessingBackendError);
  }
  // Merge rotated UV planes into the output buffer. For NV21, the UV buffer of
  // the intermediate I420 frame is swapped. MergeUVPlane builds the interleaved
  // VU buffer for NV21 by putting the U plane in the I420 frame which is
  // actually the V plane from the input buffer first.
  const uint8* output_chroma_buffer =
      buffer.format() == FrameBuffer::Format::kNV12 ? output_data.u_buffer
                                                    : output_data.v_buffer;
  // The width and height arguments of `libyuv::MergeUVPlane()` represent the
  // width and height of the UV planes (i.e. half-resolution, rounded up).
  libyuv::MergeUVPlane(
      rotated_yuv_data.u_buffer, rotated_yuv_data.uv_row_stride,
      rotated_yuv_data.v_buffer, rotated_yuv_data.uv_row_stride,
      const_cast<uint8*>(output_chroma_buffer), output_data.uv_row_stride,
      (output_buffer->dimension().width + 1) / 2,
      (output_buffer->dimension().height + 1) / 2);
  return absl::OkStatus();
}

// This method only supports kGRAY, kRGB, and kRGBA format.
+absl::Status FlipPlaneVertically(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + + ASSIGN_OR_RETURN(int pixel_stride, GetPixelStrides(buffer.format())); + + // Flip vertically is achieved by passing in negative height. + libyuv::CopyPlane(buffer.plane(0).buffer, + buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + output_buffer->dimension().width * pixel_stride, + -output_buffer->dimension().height); + + return absl::OkStatus(); +} + +// This method only supports kGRAY, kRGBA, and kRGB formats. +absl::Status CropPlane(const FrameBuffer& buffer, int x0, int y0, int x1, + int y1, FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + + ASSIGN_OR_RETURN(int pixel_stride, GetPixelStrides(buffer.format())); + FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1); + + // Cropping is achieved by adjusting origin to (x0, y0). + int adjusted_offset = + buffer.plane(0).stride.row_stride_bytes * y0 + x0 * pixel_stride; + + libyuv::CopyPlane(buffer.plane(0).buffer + adjusted_offset, + buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + crop_dimension.width * pixel_stride, crop_dimension.height); + + return absl::OkStatus(); +} + +// Crops NV12/NV21 FrameBuffer to the subregion defined by the top left pixel +// position (x0, y0) and the bottom right pixel position (x1, y1). 
+absl::Status CropNv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + // Crop Y plane by copying the buffer with the origin offset to (x0, y0). + int crop_offset_y = input_data.y_row_stride * y0 + x0; + int crop_width = x1 - x0 + 1; + int crop_height = y1 - y0 + 1; + libyuv::CopyPlane(input_data.y_buffer + crop_offset_y, + input_data.y_row_stride, + const_cast<uint8*>(output_data.y_buffer), + output_data.y_row_stride, crop_width, crop_height); + // Crop chroma plane by copying the buffer with the origin offset to + // (x0 / 2, y0 / 2); + // TODO(b/152629712): Investigate the impact of color shifting caused by the + // bounding box with odd X or Y starting positions. + int crop_offset_chroma = input_data.uv_row_stride * (y0 / 2) + + input_data.uv_pixel_stride * (x0 / 2); + ASSIGN_OR_RETURN(const uint8* input_chroma_buffer, GetUvRawBuffer(buffer)); + ASSIGN_OR_RETURN(const uint8* output_chroma_buffer, + GetUvRawBuffer(*output_buffer)); + libyuv::CopyPlane( + input_chroma_buffer + crop_offset_chroma, input_data.uv_row_stride, + const_cast<uint8*>(output_chroma_buffer), output_data.uv_row_stride, + /*width=*/(crop_width + 1) / 2 * 2, /*height=*/(crop_height + 1) / 2); + return absl::OkStatus(); +} + +// Crops YV12/YV21 FrameBuffer to the subregion defined by the top left pixel +// position (x0, y0) and the bottom right pixel position (x1, y1). 
+absl::Status CropYv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + // Crop Y plane by copying the buffer with the origin offset to (x0, y0). + int crop_offset_y = input_data.y_row_stride * y0 + x0; + FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1); + libyuv::CopyPlane( + input_data.y_buffer + crop_offset_y, input_data.y_row_stride, + const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride, + crop_dimension.width, crop_dimension.height); + // Crop U plane by copying the buffer with the origin offset to + // (x0 / 2, y0 / 2). + ASSIGN_OR_RETURN(const FrameBuffer::Dimension crop_uv_dimension, + GetUvPlaneDimension(crop_dimension, buffer.format())); + // TODO(b/152629712): Investigate the impact of color shifting caused by the + // bounding box with odd X or Y starting positions. 
+ int crop_offset_chroma = input_data.uv_row_stride * (y0 / 2) + + input_data.uv_pixel_stride * (x0 / 2); + libyuv::CopyPlane( + input_data.u_buffer + crop_offset_chroma, input_data.uv_row_stride, + const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride, + crop_uv_dimension.width, crop_uv_dimension.height); + // Crop V plane by copying the buffer with the origin offset to + // (x0 / 2, y0 / 2); + libyuv::CopyPlane( + input_data.v_buffer + crop_offset_chroma, input_data.uv_row_stride, + const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride, + /*width=*/(crop_dimension.width + 1) / 2, + /*height=*/(crop_dimension.height + 1) / 2); + return absl::OkStatus(); +} + +absl::Status CropResizeYuv(const FrameBuffer& buffer, int x0, int y0, int x1, + int y1, FrameBuffer* output_buffer) { + FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1); + if (crop_dimension == output_buffer->dimension()) { + switch (buffer.format()) { + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + return CropNv(buffer, x0, y0, x1, y1, output_buffer); + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return CropYv(buffer, x0, y0, x1, y1, output_buffer); + default: + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Format %i is not supported.", buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + } + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + // Cropping YUV planes by offsetting the origins of each plane. + // TODO(b/152629712): Investigate the impact of color shifting caused by the + // bounding box with odd X or Y starting positions. 
+ const int plane_y_offset = input_data.y_row_stride * y0 + x0; + const int plane_uv_offset = input_data.uv_row_stride * (y0 / 2) + + input_data.uv_pixel_stride * (x0 / 2); + FrameBuffer::Plane cropped_plane_y = { + /*buffer=*/input_data.y_buffer + plane_y_offset, + /*stride=*/{input_data.y_row_stride, /*pixel_stride_bytes=*/1}}; + FrameBuffer::Plane cropped_plane_u = { + /*buffer=*/input_data.u_buffer + plane_uv_offset, + /*stride=*/{input_data.uv_row_stride, input_data.uv_pixel_stride}}; + FrameBuffer::Plane cropped_plane_v = { + /*buffer=*/input_data.v_buffer + plane_uv_offset, + /*stride=*/{input_data.uv_row_stride, input_data.uv_pixel_stride}}; + + switch (buffer.format()) { + case FrameBuffer::Format::kNV12: { + std::unique_ptr<FrameBuffer> cropped_buffer = FrameBuffer::Create( + {cropped_plane_y, cropped_plane_u, cropped_plane_v}, crop_dimension, + buffer.format(), buffer.orientation()); + return ResizeNv(*cropped_buffer, output_buffer); + } + case FrameBuffer::Format::kNV21: { + std::unique_ptr<FrameBuffer> cropped_buffer = FrameBuffer::Create( + {cropped_plane_y, cropped_plane_v, cropped_plane_u}, crop_dimension, + buffer.format(), buffer.orientation()); + return ResizeNv(*cropped_buffer, output_buffer); + } + case FrameBuffer::Format::kYV12: { + std::unique_ptr<FrameBuffer> cropped_buffer = FrameBuffer::Create( + {cropped_plane_y, cropped_plane_v, cropped_plane_u}, crop_dimension, + buffer.format(), buffer.orientation()); + return ResizeYv(*cropped_buffer, output_buffer); + } + case FrameBuffer::Format::kYV21: { + std::unique_ptr<FrameBuffer> cropped_buffer = FrameBuffer::Create( + {cropped_plane_y, cropped_plane_u, cropped_plane_v}, crop_dimension, + buffer.format(), buffer.orientation()); + return ResizeYv(*cropped_buffer, output_buffer); + } + default: + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Format %i is not supported.", buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + return absl::OkStatus(); 
+} + +absl::Status FlipHorizontallyRgba(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + + int ret = libyuv::ARGBMirror( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + output_buffer->dimension().width, output_buffer->dimension().height); + + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ARGBMirror operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + + return absl::OkStatus(); +} + +// Flips `buffer` horizontally and store the result in `output_buffer`. This +// method assumes all buffers have pixel stride equals to 1. +absl::Status FlipHorizontallyPlane(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + libyuv::MirrorPlane( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + output_buffer->dimension().width, output_buffer->dimension().height); + + return absl::OkStatus(); +} + +absl::Status ResizeRgb(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + + // libyuv doesn't support scale kRGB (RGB24) foramat. 
In this method, + // the implementation converts kRGB format to ARGB and use ARGB buffer for + // scaling. The result is then convert back to RGB. + + // Convert RGB to ARGB + int argb_buffer_size = + GetFrameBufferByteSize(buffer.dimension(), FrameBuffer::Format::kRGBA); + auto argb_buffer = absl::make_unique<uint8[]>(argb_buffer_size); + const int argb_row_bytes = buffer.dimension().width * kRgbaPixelBytes; + RETURN_IF_ERROR(ConvertRgbToArgb(buffer, argb_buffer.get(), argb_row_bytes)); + + // Resize ARGB + int resized_argb_buffer_size = GetFrameBufferByteSize( + output_buffer->dimension(), FrameBuffer::Format::kRGBA); + auto resized_argb_buffer = + absl::make_unique<uint8[]>(resized_argb_buffer_size); + int resized_argb_row_bytes = + output_buffer->dimension().width * kRgbaPixelBytes; + int ret = libyuv::ARGBScale( + argb_buffer.get(), argb_row_bytes, buffer.dimension().width, + buffer.dimension().height, resized_argb_buffer.get(), + resized_argb_row_bytes, output_buffer->dimension().width, + output_buffer->dimension().height, libyuv::FilterMode::kFilterBilinear); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ARGBScale operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + + // Convert ARGB to RGB + return ConvertArgbToRgb(resized_argb_buffer.get(), resized_argb_row_bytes, + output_buffer); +} + +// Horizontally flip `buffer` and store the result in `output_buffer`. 
+absl::Status FlipHorizontallyRgb(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + +#if LIBYUV_VERSION >= 1747 + int ret = libyuv::RGB24Mirror( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, buffer.dimension().width, + buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv RGB24Mirror operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + + return absl::OkStatus(); +#else +#error LibyuvFrameBufferUtils requires LIBYUV_VERSION 1747 or above +#endif // LIBYUV_VERSION >= 1747 +} + +absl::Status ResizeRgba(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + int ret = libyuv::ARGBScale( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + output_buffer->dimension().width, output_buffer->dimension().height, + libyuv::FilterMode::kFilterBilinear); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv ARGBScale operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); +} + +// Flips NV12/NV21 FrameBuffer horizontally. 
+absl::Status FlipHorizontallyNv(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + ASSIGN_OR_RETURN(const uint8* input_chroma_buffer, GetUvRawBuffer(buffer)); + ASSIGN_OR_RETURN(const uint8* output_chroma_buffer, + GetUvRawBuffer(*output_buffer)); + + int ret = libyuv::NV12Mirror( + input_data.y_buffer, input_data.y_row_stride, input_chroma_buffer, + input_data.uv_row_stride, const_cast<uint8*>(output_data.y_buffer), + output_data.y_row_stride, const_cast<uint8*>(output_chroma_buffer), + output_data.uv_row_stride, buffer.dimension().width, + buffer.dimension().height); + + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv NV12Mirror operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + + return absl::OkStatus(); +} + +// Flips YV12/YV21 FrameBuffer horizontally. 
+absl::Status FlipHorizontallyYv(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + int ret = libyuv::I420Mirror( + input_data.y_buffer, input_data.y_row_stride, input_data.u_buffer, + input_data.uv_row_stride, input_data.v_buffer, input_data.uv_row_stride, + const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride, + const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride, + const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride, + buffer.dimension().width, buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv I420Mirror operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + + return absl::OkStatus(); +} + +// Flips NV12/NV21 FrameBuffer vertically. +absl::Status FlipVerticallyNv(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + // Flip Y plane vertically by passing a negative height. + libyuv::CopyPlane(input_data.y_buffer, input_data.y_row_stride, + const_cast<uint8*>(output_data.y_buffer), + output_data.y_row_stride, buffer.dimension().width, + -output_buffer->dimension().height); + // Flip UV plane vertically by passing a negative height. 
+ ASSIGN_OR_RETURN(const uint8* input_chroma_buffer, GetUvRawBuffer(buffer)); + ASSIGN_OR_RETURN(const uint8* output_chroma_buffer, + GetUvRawBuffer(*output_buffer)); + ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_plane_dimension, + GetUvPlaneDimension(buffer.dimension(), buffer.format())); + libyuv::CopyPlane( + input_chroma_buffer, input_data.uv_row_stride, + const_cast<uint8*>(output_chroma_buffer), output_data.uv_row_stride, + /*width=*/uv_plane_dimension.width * 2, -uv_plane_dimension.height); + return absl::OkStatus(); +} + +// Flips NV12/NV21 FrameBuffer vertically. +absl::Status FlipVerticallyYv(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData input_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + ASSIGN_OR_RETURN(FrameBuffer::YuvData output_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_buffer)); + // Flip buffer vertically by passing a negative height. + int ret = libyuv::I420Copy( + input_data.y_buffer, input_data.y_row_stride, input_data.u_buffer, + input_data.uv_row_stride, input_data.v_buffer, input_data.uv_row_stride, + const_cast<uint8*>(output_data.y_buffer), output_data.y_row_stride, + const_cast<uint8*>(output_data.u_buffer), output_data.uv_row_stride, + const_cast<uint8*>(output_data.v_buffer), output_data.uv_row_stride, + buffer.dimension().width, -buffer.dimension().height); + if (ret != 0) { + return CreateStatusWithPayload( + StatusCode::kUnknown, "Libyuv I420Copy operation failed.", + TfLiteSupportStatus::kImageProcessingBackendError); + } + return absl::OkStatus(); +} + +// Resize `buffer` to metadata defined in `output_buffer`. This +// method assumes buffer has pixel stride equals to 1 (grayscale equivalent). 
+absl::Status ResizeGray(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + if (buffer.plane_count() > 1) { + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Only single plane is supported for format %i.", + buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } + libyuv::ScalePlane( + buffer.plane(0).buffer, buffer.plane(0).stride.row_stride_bytes, + buffer.dimension().width, buffer.dimension().height, + const_cast<uint8*>(output_buffer->plane(0).buffer), + output_buffer->plane(0).stride.row_stride_bytes, + output_buffer->dimension().width, output_buffer->dimension().height, + libyuv::FilterMode::kFilterBilinear); + return absl::OkStatus(); +} + +// This method only supports kGRAY, kRGBA, and kRGB formats. +absl::Status CropResize(const FrameBuffer& buffer, int x0, int y0, int x1, + int y1, FrameBuffer* output_buffer) { + FrameBuffer::Dimension crop_dimension = GetCropDimension(x0, x1, y0, y1); + if (crop_dimension == output_buffer->dimension()) { + return CropPlane(buffer, x0, y0, x1, y1, output_buffer); + } + + ASSIGN_OR_RETURN(int pixel_stride, GetPixelStrides(buffer.format())); + // Cropping is achieved by adjusting origin to (x0, y0). 
+ int adjusted_offset = + buffer.plane(0).stride.row_stride_bytes * y0 + x0 * pixel_stride; + FrameBuffer::Plane plane = { + /*buffer=*/buffer.plane(0).buffer + adjusted_offset, + /*stride=*/{buffer.plane(0).stride.row_stride_bytes, pixel_stride}}; + auto adjusted_buffer = + FrameBuffer::Create({plane}, crop_dimension, buffer.format(), + buffer.orientation(), buffer.timestamp()); + + switch (buffer.format()) { + case FrameBuffer::Format::kRGB: + return ResizeRgb(*adjusted_buffer, output_buffer); + case FrameBuffer::Format::kRGBA: + return ResizeRgba(*adjusted_buffer, output_buffer); + case FrameBuffer::Format::kGRAY: + return ResizeGray(*adjusted_buffer, output_buffer); + default: + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Format %i is not supported.", buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } +} +} // namespace + +absl::Status LibyuvFrameBufferUtils::Crop(const FrameBuffer& buffer, int x0, + int y0, int x1, int y1, + FrameBuffer* output_buffer) { + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer)); + RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer)); + RETURN_IF_ERROR( + ValidateCropBufferInputs(buffer, *output_buffer, x0, y0, x1, y1)); + RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer)); + + switch (buffer.format()) { + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kRGB: + case FrameBuffer::Format::kGRAY: + return CropResize(buffer, x0, y0, x1, y1, output_buffer); + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return CropResizeYuv(buffer, x0, y0, x1, y1, output_buffer); + default: + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Format %i is not supported.", buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } +} + +absl::Status LibyuvFrameBufferUtils::Resize(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + 
  RETURN_IF_ERROR(ValidateResizeBufferInputs(buffer, *output_buffer));
  // Dispatch to the format-specific resize implementation.
  switch (buffer.format()) {
    case FrameBuffer::Format::kYV12:
    case FrameBuffer::Format::kYV21:
      return ResizeYv(buffer, output_buffer);
    case FrameBuffer::Format::kNV12:
    case FrameBuffer::Format::kNV21:
      return ResizeNv(buffer, output_buffer);
    case FrameBuffer::Format::kRGB:
      return ResizeRgb(buffer, output_buffer);
    case FrameBuffer::Format::kRGBA:
      return ResizeRgba(buffer, output_buffer);
    case FrameBuffer::Format::kGRAY:
      return ResizeGray(buffer, output_buffer);
    default:
      return CreateStatusWithPayload(
          StatusCode::kInternal,
          absl::StrFormat("Format %i is not supported.", buffer.format()),
          TfLiteSupportStatus::kImageProcessingError);
  }
}

absl::Status LibyuvFrameBufferUtils::Rotate(const FrameBuffer& buffer,
                                            int angle_deg,
                                            FrameBuffer* output_buffer) {
  RETURN_IF_ERROR(
      ValidateRotateBufferInputs(buffer, *output_buffer, angle_deg));
  RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer));
  RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
  RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));

  // Dispatch to the format-specific rotation implementation.
  switch (buffer.format()) {
    case FrameBuffer::Format::kGRAY:
      return RotateGray(buffer, angle_deg, output_buffer);
    case FrameBuffer::Format::kRGBA:
      return RotateRgba(buffer, angle_deg, output_buffer);
    case FrameBuffer::Format::kNV12:
    case FrameBuffer::Format::kNV21:
      return RotateNv(buffer, angle_deg, output_buffer);
    case FrameBuffer::Format::kYV12:
    case FrameBuffer::Format::kYV21:
      return RotateYv(buffer, angle_deg, output_buffer);
    case FrameBuffer::Format::kRGB:
      return RotateRgb(buffer, angle_deg, output_buffer);
    default:
      return CreateStatusWithPayload(
          StatusCode::kInternal,
          absl::StrFormat("Format %i is not supported.", buffer.format()),
          TfLiteSupportStatus::kImageProcessingError);
  }
}

absl::Status LibyuvFrameBufferUtils::FlipHorizontally(
    const FrameBuffer& buffer, FrameBuffer* output_buffer) {
  RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
  RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
  RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
  RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer));

  // Dispatch to the format-specific horizontal-flip implementation.
  switch (buffer.format()) {
    case FrameBuffer::Format::kRGBA:
      return FlipHorizontallyRgba(buffer, output_buffer);
    case FrameBuffer::Format::kYV12:
    case FrameBuffer::Format::kYV21:
      return FlipHorizontallyYv(buffer, output_buffer);
    case FrameBuffer::Format::kNV12:
    case FrameBuffer::Format::kNV21:
      return FlipHorizontallyNv(buffer, output_buffer);
    case FrameBuffer::Format::kRGB:
      return FlipHorizontallyRgb(buffer, output_buffer);
    case FrameBuffer::Format::kGRAY:
      return FlipHorizontallyPlane(buffer, output_buffer);
    default:
      return CreateStatusWithPayload(
          StatusCode::kInternal,
          absl::StrFormat("Format %i is not supported.", buffer.format()),
          TfLiteSupportStatus::kImageProcessingError);
  }
}

absl::Status LibyuvFrameBufferUtils::FlipVertically(
    const FrameBuffer& buffer, FrameBuffer* output_buffer) {
  RETURN_IF_ERROR(ValidateBufferPlaneMetadata(buffer));
  RETURN_IF_ERROR(ValidateBufferPlaneMetadata(*output_buffer));
  RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer));
  RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer));

  // Dispatch to the format-specific vertical-flip implementation. The packed
  // single-plane formats share one path (negative-height plane copy).
  switch (buffer.format()) {
    case FrameBuffer::Format::kRGBA:
    case FrameBuffer::Format::kRGB:
    case FrameBuffer::Format::kGRAY:
      return FlipPlaneVertically(buffer, output_buffer);
    case FrameBuffer::Format::kNV12:
    case FrameBuffer::Format::kNV21:
      return FlipVerticallyNv(buffer, output_buffer);
    case FrameBuffer::Format::kYV12:
    case FrameBuffer::Format::kYV21:
      return FlipVerticallyYv(buffer, output_buffer);
    default:
      return CreateStatusWithPayload(
          StatusCode::kInternal,
          absl::StrFormat("Format %i is not supported.", buffer.format()),
          TfLiteSupportStatus::kImageProcessingError);
  }
}

absl::Status
LibyuvFrameBufferUtils::Convert(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + RETURN_IF_ERROR( + ValidateConvertFormats(buffer.format(), output_buffer->format())); + switch (buffer.format()) { + case FrameBuffer::Format::kNV12: + return ConvertFromNv12(buffer, output_buffer); + case FrameBuffer::Format::kNV21: + return ConvertFromNv21(buffer, output_buffer); + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return ConvertFromYv(buffer, output_buffer); + case FrameBuffer::Format::kRGB: + return ConvertFromRgb(buffer, output_buffer); + case FrameBuffer::Format::kRGBA: + return ConvertFromRgba(buffer, output_buffer); + default: + return CreateStatusWithPayload( + StatusCode::kInternal, + absl::StrFormat("Format %i is not supported.", buffer.format()), + TfLiteSupportStatus::kImageProcessingError); + } +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h b/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h new file mode 100644 index 00000000..0d001c8c --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/libyuv_frame_buffer_utils.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_LIBYUV_FRAME_BUFFER_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_LIBYUV_FRAME_BUFFER_UTILS_H_ + +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils_interface.h" + +namespace tflite { +namespace task { +namespace vision { + +// Libyuv image processing engine conforms to FrameBufferUtilsInterface. +// Although this class provides public APIs, it is recommended to use the public +// APIs defined in frame_buffer_utils.h for higher level abstraction and better +// functionality support. +class LibyuvFrameBufferUtils : public FrameBufferUtilsInterface { + public: + LibyuvFrameBufferUtils() = default; + ~LibyuvFrameBufferUtils() override = default; + + // Crops input `buffer` to the specified subregions and resizes the cropped + // region to the target image resolution defined by the `output_buffer`. + // + // (x0, y0) represents the top-left point of the buffer. + // (x1, y1) represents the bottom-right point of the buffer. + // + // Crop region dimensions must be equal or smaller than input `buffer` + // dimensions. + absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, + FrameBuffer* output_buffer) override; + + // Resizes `buffer` to the size of the given `output_buffer`. + absl::Status Resize(const FrameBuffer& buffer, + FrameBuffer* output_buffer) override; + + // Rotates `buffer` counter-clockwise by the given `angle_deg` (in degrees). + // + // The given angle must be a multiple of 90 degrees. + absl::Status Rotate(const FrameBuffer& buffer, int angle_deg, + FrameBuffer* output_buffer) override; + + // Flips `buffer` horizontally. + absl::Status FlipHorizontally(const FrameBuffer& buffer, + FrameBuffer* output_buffer) override; + + // Flips `buffer` vertically. 
+ absl::Status FlipVertically(const FrameBuffer& buffer, + FrameBuffer* output_buffer) override; + + // Converts `buffer`'s format to the format of the given `output_buffer`. + // + // Grayscale format cannot be converted to other formats. + absl::Status Convert(const FrameBuffer& buffer, + FrameBuffer* output_buffer) override; +}; + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_LIBYUV_FRAME_BUFFER_UTILS_H_ diff --git a/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc b/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc new file mode 100644 index 00000000..773ab76f --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/score_calibration.cc @@ -0,0 +1,225 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h" + +#include <cmath> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" + +namespace tflite { +namespace task { +namespace vision { +namespace { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +// Used to prevent log(<=0.0) in ClampedLog() calls. +constexpr float kLogScoreMinimum = 1e-16; + +// Returns the following, depending on x: +// x => threshold: log(x) +// x < threshold: 2 * log(thresh) - log(2 * thresh - x) +// This form (a) is anti-symmetric about the threshold and (b) has continuous +// value and first derivative. This is done to prevent taking the log of values +// close to 0 which can lead to floating point errors and is better than simple +// clamping since it preserves order for scores less than the threshold. +float ClampedLog(float x, float threshold) { + if (x < threshold) { + return 2.0 * std::log(static_cast<double>(threshold)) - + log(2.0 * threshold - x); + } + return std::log(static_cast<double>(x)); +} + +// Applies the specified score transformation to the provided score. 
+// Currently supports the following, +// IDENTITY : f(x) = x +// LOG : f(x) = log(x) +// INVERSE_LOGISTIC : f(x) = log(x) - log(1-x) +float ApplyScoreTransformation(float score, const ScoreTransformation& type) { + switch (type) { + case ScoreTransformation::kIDENTITY: + return score; + case ScoreTransformation::kINVERSE_LOGISTIC: + return (ClampedLog(score, kLogScoreMinimum) - + ClampedLog(1.0 - score, kLogScoreMinimum)); + case ScoreTransformation::kLOG: + return ClampedLog(score, kLogScoreMinimum); + } +} + +// Builds a single Sigmoid from the label name and associated CSV file line. +StatusOr<Sigmoid> SigmoidFromLabelAndLine(absl::string_view label, + absl::string_view line) { + std::vector<absl::string_view> str_params = absl::StrSplit(line, ','); + if (str_params.size() != 3 && str_params.size() != 4) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Expected 3 or 4 parameters per line in score " + "calibration file, got %d.", + str_params.size()), + TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError); + } + std::vector<float> float_params(4); + for (int i = 0; i < str_params.size(); ++i) { + if (!absl::SimpleAtof(str_params[i], &float_params[i])) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Could not parse score calibration parameter as float: %s.", + str_params[i]), + TfLiteSupportStatus::kMetadataMalformedScoreCalibrationError); + } + } + Sigmoid sigmoid; + sigmoid.label = std::string(label); + sigmoid.scale = float_params[0]; + sigmoid.slope = float_params[1]; + sigmoid.offset = float_params[2]; + if (str_params.size() == 4) { + sigmoid.min_uncalibrated_score = float_params[3]; + } + return sigmoid; +} + +// Converts a tflite::ScoreTransformationType to its +// tflite::task::vision::ScoreTransformation equivalent. 
+ScoreTransformation ConvertScoreTransformationType( + tflite::ScoreTransformationType type) { + switch (type) { + case tflite::ScoreTransformationType_IDENTITY: + return ScoreTransformation::kIDENTITY; + case tflite::ScoreTransformationType_LOG: + return ScoreTransformation::kLOG; + case tflite::ScoreTransformationType_INVERSE_LOGISTIC: + return ScoreTransformation::kINVERSE_LOGISTIC; + } +} + +} // namespace + +std::ostream& operator<<(std::ostream& os, const Sigmoid& s) { + os << s.label << "," << s.slope << "," << s.offset << "," << s.scale; + if (s.min_uncalibrated_score.has_value()) { + os << "," << s.min_uncalibrated_score.value(); + } + return os; +} + +ScoreCalibration::ScoreCalibration() {} +ScoreCalibration::~ScoreCalibration() {} + +absl::Status ScoreCalibration::InitializeFromParameters( + const SigmoidCalibrationParameters& params) { + sigmoid_parameters_ = std::move(params); + // Fill in the map from label -> sigmoid. + sigmoid_parameters_map_.clear(); + for (const auto& sigmoid : sigmoid_parameters_.sigmoid) { + sigmoid_parameters_map_.insert_or_assign(sigmoid.label, sigmoid); + } + return absl::OkStatus(); +} + +float ScoreCalibration::ComputeCalibratedScore(const std::string& label, + float uncalibrated_score) const { + absl::optional<Sigmoid> sigmoid = FindSigmoidParameters(label); + if (!sigmoid.has_value() || + (sigmoid.value().min_uncalibrated_score.has_value() && + uncalibrated_score < sigmoid.value().min_uncalibrated_score.value())) { + return sigmoid_parameters_.default_score; + } + + float transformed_score = ApplyScoreTransformation( + uncalibrated_score, sigmoid_parameters_.score_transformation); + float scale_shifted_score = + transformed_score * sigmoid.value().slope + sigmoid.value().offset; + + // For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0 + // and exp(x) / (1+exp(x)) when scale_shifted_score < 0. 
+ if (scale_shifted_score >= 0.0) { + return sigmoid.value().scale / + (1.0 + std::exp(static_cast<double>(-scale_shifted_score))); + } else { + float score_exp = std::exp(static_cast<double>(scale_shifted_score)); + return sigmoid.value().scale * score_exp / (1.0 + score_exp); + } +} + +absl::optional<Sigmoid> ScoreCalibration::FindSigmoidParameters( + const std::string& label) const { + auto it = sigmoid_parameters_map_.find(label); + if (it != sigmoid_parameters_map_.end()) { + return it->second; + } else if (sigmoid_parameters_.default_sigmoid.has_value()) { + return sigmoid_parameters_.default_sigmoid.value(); + } + return absl::nullopt; +} + +StatusOr<SigmoidCalibrationParameters> BuildSigmoidCalibrationParams( + const tflite::ScoreCalibrationOptions& score_calibration_options, + absl::string_view score_calibration_file, + const std::vector<LabelMapItem>& label_map_items) { + // Split file lines and perform sanity checks. + if (score_calibration_file.empty()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Expected non-empty score calibration file."); + } + std::vector<absl::string_view> lines = + absl::StrSplit(score_calibration_file, '\n'); + if (label_map_items.size() != lines.size()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Mismatch between number of labels (%d) and score " + "calibration parameters (%d).", + label_map_items.size(), lines.size()), + TfLiteSupportStatus::kMetadataNumLabelsMismatchError); + } + // Initialize SigmoidCalibrationParameters with its class-agnostic parameters. + SigmoidCalibrationParameters sigmoid_params = {}; + sigmoid_params.score_transformation = ConvertScoreTransformationType( + score_calibration_options.score_transformation()); + sigmoid_params.default_score = score_calibration_options.default_score(); + std::vector<Sigmoid> sigmoid_vector; + // Fill sigmoids for each class with parameters in the file. 
+ for (int i = 0; i < label_map_items.size(); ++i) { + if (lines[i].empty()) { + continue; + } + ASSIGN_OR_RETURN(Sigmoid sigmoid, SigmoidFromLabelAndLine( + label_map_items[i].name, lines[i])); + sigmoid_vector.emplace_back(std::move(sigmoid)); + } + sigmoid_params.sigmoid = std::move(sigmoid_vector); + + return sigmoid_params; +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h b/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h new file mode 100644 index 00000000..c3f0bf8a --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/utils/score_calibration.h @@ -0,0 +1,146 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_ + +#include <iostream> +#include <map> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace task { +namespace vision { + +// Sigmoid structure. +struct Sigmoid { + Sigmoid() : scale(1.0) {} + Sigmoid(std::string label, float slope, float offset, float scale = 1.0, + absl::optional<float> min_uncalibrated_score = absl::nullopt) + : label(label), + slope(slope), + offset(offset), + scale(scale), + min_uncalibrated_score(min_uncalibrated_score) {} + + bool operator==(const Sigmoid& other) const { + return label == other.label && slope == other.slope && + offset == other.offset && scale == other.scale && + min_uncalibrated_score == other.min_uncalibrated_score; + } + + // Unique label corresponding to the sigmoid parameters. + std::string label; + float slope; + float offset; + float scale; + absl::optional<float> min_uncalibrated_score; +}; + +std::ostream& operator<<(std::ostream& os, const Sigmoid& s); + +// Transformation function to use for computing transformation scores. +enum class ScoreTransformation { + kIDENTITY, // f(x) = x + kLOG, // f(x) = log(x) + kINVERSE_LOGISTIC // f(x) = log(x) - log(1 - x) +}; + +// Sigmoid calibration parameters. 
+struct SigmoidCalibrationParameters { + SigmoidCalibrationParameters() + : default_score(0.0), + score_transformation(ScoreTransformation::kIDENTITY) {} + explicit SigmoidCalibrationParameters( + std::vector<Sigmoid> sigmoid, + ScoreTransformation score_transformation = ScoreTransformation::kIDENTITY, + absl::optional<Sigmoid> default_sigmoid = absl::nullopt, + float default_score = 0.0) + : sigmoid(sigmoid), + default_sigmoid(default_sigmoid), + default_score(default_score), + score_transformation(score_transformation) {} + // A vector of Sigmoid associated to the ScoreCalibration instance. + std::vector<Sigmoid> sigmoid; + // If set, this sigmoid will be applied to any non-matching labels. + absl::optional<Sigmoid> default_sigmoid; + // The default score for non-matching labels. Only used if default_sigmoid + // isn't set. + float default_score; + // Function for computing a transformation score prior to sigmoid fitting. + ScoreTransformation score_transformation; +}; + +// This class is used to calibrate predicted scores so that scores are +// comparable across labels. Depending on the particular calibration parameters +// being used, the calibrated scores can also be approximately interpreted as a +// likelihood of being correct. For a given TF Lite model, such parameters are +// typically obtained from TF Lite Metadata (see ScoreCalibrationOptions). +class ScoreCalibration { + public: + ScoreCalibration(); + ~ScoreCalibration(); + + // Transfers input parameters and construct a label to sigmoid map. + absl::Status InitializeFromParameters( + const SigmoidCalibrationParameters& params); + + // Returns a calibrated score given a label string and uncalibrated score. The + // calibrated score will be in the range [0.0, 1.0] and can loosely be + // interpreted as a likelihood of the label being correct. 
+ float ComputeCalibratedScore(const std::string& label, + float uncalibrated_score) const; + + private: + // Finds the sigmoid parameters corresponding to the provided label. + absl::optional<Sigmoid> FindSigmoidParameters(const std::string& label) const; + + // Parameters for internal states. + SigmoidCalibrationParameters sigmoid_parameters_; + + // Maps label strings to the particular sigmoid stored in sigmoid_parameters_. + absl::flat_hash_map<std::string, Sigmoid> sigmoid_parameters_map_; +}; + +// Builds SigmoidCalibrationParameters using data obtained from TF Lite Metadata +// (see ScoreCalibrationOptions in metadata schema). +// +// The provided `score_calibration_file` represents the contents of the score +// calibration associated file (TENSOR_AXIS_SCORE_CALIBRATION), i.e. one set of +// parameters (scale, slope, etc) per line. Each line must be in 1:1 +// correspondence with `label_map_items`, so as to associate each sigmoid to its +// corresponding label name. Returns an error if no valid parameters could be +// built (e.g. malformed parameters). +tflite::support::StatusOr<SigmoidCalibrationParameters> +BuildSigmoidCalibrationParams( + const tflite::ScoreCalibrationOptions& score_calibration_options, + absl::string_view score_calibration_file, + const std::vector<LabelMapItem>& label_map_items); + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_SCORE_CALIBRATION_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/BUILD b/tensorflow_lite_support/cc/text/tokenizers/BUILD new file mode 100644 index 00000000..3ad8da2f --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/BUILD @@ -0,0 +1,191 @@ +# This package contains C++ support libraries that Java libraries can invoke. 
+load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load( + "@org_tensorflow//tensorflow/lite:build_def.bzl", + "tflite_copts", + "tflite_jni_binary", +) + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tokenizer", + hdrs = [ + "tokenizer.h", + ], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "tokenizer_jni_lib", + srcs = [ + "tokenizer_jni_lib.cc", + ], + hdrs = [ + "tokenizer_jni_lib.h", + ], + deps = [ + ":tokenizer", + "//tensorflow_lite_support/cc/utils:jni_utils", + "@org_tensorflow//tensorflow/lite/java/jni", + ], +) + +cc_library( + name = "bert_tokenizer", + srcs = [ + "bert_tokenizer.cc", + ], + hdrs = [ + "bert_tokenizer.h", + ], + deps = [ + ":tokenizer", + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/utils:common_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_googlesource_code_re2//:re2", + "@org_tensorflow_text//tensorflow_text/core/kernels:regex_split", + "@org_tensorflow_text//tensorflow_text/core/kernels:wordpiece_tokenizer", + ], +) + +cc_library( + name = "bert_tokenizer_jni_lib", + srcs = [ + "bert_tokenizer_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + ":bert_tokenizer", + ":tokenizer_jni_lib", + "//tensorflow_lite_support/cc/utils:jni_utils", + "@com_google_absl//absl/memory", + "@org_tensorflow//tensorflow/lite/java/jni", + ], + alwayslink = 1, +) + +tflite_jni_binary( + name = "libbert_tokenizer_jni.so", + deps = [ + ":bert_tokenizer_jni_lib", + ], +) + +cc_library( + name = "bert_tokenizer_runtime", + srcs = ["libbert_tokenizer_jni.so"], + alwayslink = 1, +) + +android_library( + name = "bert_tokenizer_jni", + custom_package = "org.tensorflow.lite.support.text", + manifest = "DummyManifest.xml", + resource_files = [], + deps = [ + ":bert_tokenizer_runtime", # build_cleaner: skip + ], +) + 
+cc_library( + name = "sentencepiece_tokenizer", + hdrs = [ + "sentencepiece_tokenizer.h", + ], + deps = [ + ":tokenizer", + "@com_google_sentencepiece//src:sentencepiece_processor", + ], +) + +cc_library( + name = "sentencepiece_jni_lib", + srcs = [ + "sentencepiece_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + ":sentencepiece_tokenizer", + ":tokenizer_jni_lib", + "//tensorflow_lite_support/cc/utils:jni_utils", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/java/jni", + ], + alwayslink = 1, +) + +cc_library( + name = "sentencepiece_runtime", + srcs = ["libsentencepiece_jni.so"], + alwayslink = 1, +) + +tflite_jni_binary( + name = "libsentencepiece_jni.so", + deps = [ + ":sentencepiece_jni_lib", + ], +) + +android_library( + name = "sentencepiece_jni", + custom_package = "org.tensorflow.lite.support.text", + manifest = "DummyManifest.xml", + resource_files = [], + deps = [ + ":sentencepiece_runtime", # build_cleaner: skip + ], +) + +cc_library( + name = "tokenizer_utils", + srcs = ["tokenizer_utils.cc"], + hdrs = [ + "tokenizer_utils.h", + ], + deps = [ + ":bert_tokenizer", + ":regex_tokenizer", + ":sentencepiece_tokenizer", + ":tokenizer", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "//tensorflow_lite_support/metadata/cc:metadata_extractor", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "regex_tokenizer", + srcs = [ + "regex_tokenizer.cc", + ], + hdrs = [ + "regex_tokenizer.h", + ], + deps = [ + ":tokenizer", + "//tensorflow_lite_support/cc/utils:common_utils", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) diff --git a/tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml 
b/tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml new file mode 100644 index 00000000..ff025072 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/DummyManifest.xml @@ -0,0 +1,19 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + Copyright 2019 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.support.text"> +</manifest> diff --git a/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc new file mode 100644 index 00000000..aeb887c6 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.cc @@ -0,0 +1,108 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" + +#include "tensorflow_lite_support/cc/port/integral_types.h" + +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + +FlatHashMapBackedWordpiece::FlatHashMapBackedWordpiece( + const std::vector<std::string>& vocab) + : vocab_{vocab} { + for (int i = 0; i < vocab_.size(); ++i) { + index_map_[vocab_[i]] = i; + } +} + +tensorflow::text::LookupStatus FlatHashMapBackedWordpiece::Contains( + absl::string_view key, bool* value) const { + *value = index_map_.contains(key); + return tensorflow::text::LookupStatus(); +} + +bool FlatHashMapBackedWordpiece::LookupId(const absl::string_view key, + int* result) const { + auto it = index_map_.find(key); + if (it == index_map_.end()) { + return false; + } + *result = it->second; + return true; +} + +bool FlatHashMapBackedWordpiece::LookupWord(int vocab_id, + absl::string_view* result) const { + if (vocab_id >= vocab_.size() || vocab_id < 0) { + return false; + } + *result = vocab_[vocab_id]; + return true; +} + +TokenizerResult BertTokenizer::Tokenize(const std::string& input) { + return TokenizeWordpiece(input); +} + +WordpieceTokenizerResult BertTokenizer::TokenizeWordpiece( + const std::string& input) { + WordpieceTokenizerResult result; + std::vector<std::string>& subwords = result.subwords; + std::vector<int>& wp_absolute_begin_offset = result.wp_begin_offset; + std::vector<int>& wp_absolute_end_offset = result.wp_end_offset; + + std::vector<absl::string_view> tokens; + std::vector<int64> begin_offsets; + std::vector<int64> end_offsets; + + // Run through tokenize function + tensorflow::text::RegexSplit(input, delim_re_, true, include_delim_re_, + &tokens, &begin_offsets, &end_offsets); + + for (int token_index = 0; token_index < tokens.size(); token_index++) { + auto& token = tokens[token_index]; + int num_word_pieces = 0; + 
tensorflow::text::LookupStatus status = WordpieceTokenize( + token, options_.max_bytes_per_token, options_.max_chars_per_subtoken, + options_.suffix_indicator, options_.use_unknown_token, + options_.unknown_token, options_.split_unknown_chars, &vocab_, + &subwords, &wp_absolute_begin_offset, &wp_absolute_end_offset, + &num_word_pieces); + + result.row_lengths.emplace_back(num_word_pieces); + // for the last num_word_pieces added into wp_absolute_begin_offset and + // wp_absolute_end_offset, offset them with begin_offsets[token_index] + int absolute_offset_size = wp_absolute_begin_offset.size(); + for (int i = num_word_pieces; i > 0; i--) { + wp_absolute_begin_offset[absolute_offset_size - i] += + begin_offsets[token_index]; + wp_absolute_end_offset[absolute_offset_size - i] += + begin_offsets[token_index]; + } + if (!status.success) { + return result; + } + } + + return result; +} + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h new file mode 100644 index 00000000..14a006c2 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h @@ -0,0 +1,149 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_ + +#include <fstream> +#include <string> +#include <vector> + +#include "absl/container/flat_hash_map.h" +#include "re2/re2.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/cc/utils/common_utils.h" +#include "tensorflow_text/core/kernels/regex_split.h" +#include "tensorflow_text/core/kernels/wordpiece_tokenizer.h" + +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + +constexpr char kDefaultDelimRe[] = + R"((\s+|[!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))"; +constexpr char kDefaultIncludeDelimRe[] = + R"(([!-/]|[:-@]|[\[-`]|[{-~]|[\p{P}]|[\x{4E00}-\x{9FFF}]|[\x{3400}-\x{4DBF}]|[\x{20000}-\x{2A6DF}]|[\x{2A700}-\x{2B73F}]|[\x{2B740}-\x{2B81F}]|[\x{2B820}-\x{2CEAF}]|[\x{F900}-\x{FAFF}]|[\x{2F800}-\x{2FA1F}]))"; +constexpr int kDefaultMaxBytesPerToken = 100; +constexpr int kDefaultMaxCharsPerSubToken = 100; +constexpr char kDefaultSuffixIndicator[] = "##"; +constexpr bool kDefaultUseUnknownToken = true; +constexpr char kDefaultUnknownToken[] = "[UNK]"; +constexpr bool kDefaultSplitUnknownChars = false; + +// Result of wordpiece tokenization including subwords and offsets. +// Example: +// input: tokenize me please +// subwords: token ##ize me plea ##se +// wp_begin_offset: [0, 5, 9, 12, 16] +// wp_end_offset: [ 5, 8, 11, 16, 18] +// row_lengths: [2, 1, 1] +struct WordpieceTokenizerResult : TokenizerResult { + std::vector<int> wp_begin_offset; + std::vector<int> wp_end_offset; + std::vector<int> row_lengths; +}; +// Options to create a BertTokenizer. 
+struct BertTokenizerOptions { + int max_bytes_per_token = kDefaultMaxBytesPerToken; + int max_chars_per_subtoken = kDefaultMaxCharsPerSubToken; + std::string suffix_indicator = kDefaultSuffixIndicator; + bool use_unknown_token = kDefaultUseUnknownToken; + std::string unknown_token = kDefaultUnknownToken; + bool split_unknown_chars = kDefaultSplitUnknownChars; + std::string delim_str = kDefaultDelimRe; + std::string include_delim_str = kDefaultIncludeDelimRe; +}; + +// A flat-hash-map based implementation of WordpieceVocab, used in +// BertTokenizer to invoke tensorflow::text::WordpieceTokenize within. +class FlatHashMapBackedWordpiece : public tensorflow::text::WordpieceVocab { + public: + explicit FlatHashMapBackedWordpiece(const std::vector<std::string>& vocab); + + tensorflow::text::LookupStatus Contains(absl::string_view key, + bool* value) const override; + bool LookupId(absl::string_view key, int* result) const; + bool LookupWord(int vocab_id, absl::string_view* result) const; + int VocabularySize() const { return vocab_.size(); } + + private: + // All words indexed position in vocabulary file. + std::vector<std::string> vocab_; + absl::flat_hash_map<absl::string_view, int> index_map_; +}; + +// Wordpiece tokenizer for bert models. Initialized with a vocab file or vector. +class BertTokenizer : public tflite::support::text::tokenizer::Tokenizer { + public: + // Initialize the tokenizer from vocab vector and tokenizer configs. + explicit BertTokenizer(const std::vector<std::string>& vocab, + const BertTokenizerOptions& options = {}) + : vocab_{FlatHashMapBackedWordpiece(vocab)}, + options_{options}, + delim_re_{options.delim_str}, + include_delim_re_{options.include_delim_str} {} + + // Initialize the tokenizer from file path to vocab and tokenizer configs. 
+ explicit BertTokenizer(const std::string& path_to_vocab, + const BertTokenizerOptions& options = {}) + : BertTokenizer(utils::LoadVocabFromFile(path_to_vocab), options) {} + + // Initialize the tokenizer from buffer and size of vocab and tokenizer + // configs. + BertTokenizer(const char* vocab_buffer_data, size_t vocab_buffer_size, + const BertTokenizerOptions& options = {}) + : BertTokenizer( + utils::LoadVocabFromBuffer(vocab_buffer_data, vocab_buffer_size), + options) {} + + // Perform tokenization, return tokenized results containing the subwords. + TokenizerResult Tokenize(const std::string& input) override; + + // Perform tokenization, return wordpiece-specific tokenized result including + // subwords and offsets + WordpieceTokenizerResult TokenizeWordpiece(const std::string& input); + + // Check if a certain key is included in the vocab. + tensorflow::text::LookupStatus Contains(const absl::string_view key, + bool* value) const { + return vocab_.Contains(key, value); + } + + // Find the id of a wordpiece. + bool LookupId(absl::string_view key, int* result) const override { + return vocab_.LookupId(key, result); + } + + // Find the wordpiece from an id. 
+ bool LookupWord(int vocab_id, absl::string_view* result) const override { + return vocab_.LookupWord(vocab_id, result); + } + + int VocabularySize() const { return vocab_.VocabularySize(); } + + private: + tflite::support::text::tokenizer::FlatHashMapBackedWordpiece vocab_; + BertTokenizerOptions options_; + RE2 delim_re_; + RE2 include_delim_re_; +}; + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_BERT_TOKENIZER_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc new file mode 100644 index 00000000..442d06ec --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer_jni.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <jni.h> + +#include <string> + +#include "absl/memory/memory.h" +#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace support { + +using ::tflite::support::text::tokenizer::BertTokenizer; +using ::tflite::support::text::tokenizer::BertTokenizerOptions; +using ::tflite::support::utils::StringListToVector; + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeLoadResource( // NOLINT + JNIEnv* env, jobject thiz, jobject vocab_list, jint max_bytes_per_token, + jint max_chars_per_sub_token, jstring jsuffix_indicator, + jboolean use_unknown_token, jstring junknown_token, + jboolean split_unknown_chars) { + // Convert java.util.List<String> into std::vector<string> + std::vector<std::string> vocab = StringListToVector(env, vocab_list); + + // Convert jstrings to std::string + const char* raw_suffix_indicator = + env->GetStringUTFChars(jsuffix_indicator, JNI_FALSE); + std::string suffix_indicator(raw_suffix_indicator); + + const char* raw_unknown_token = + env->GetStringUTFChars(junknown_token, JNI_FALSE); + std::string unknown_token(raw_unknown_token); + + auto handle = absl::make_unique<BertTokenizer>( + vocab, BertTokenizerOptions{ + .max_bytes_per_token = max_bytes_per_token, + .max_chars_per_subtoken = max_chars_per_sub_token, + .suffix_indicator = suffix_indicator, + .use_unknown_token = static_cast<bool>(use_unknown_token), + .unknown_token = unknown_token, + .split_unknown_chars = static_cast<bool>(split_unknown_chars), + .delim_str = text::tokenizer::kDefaultDelimRe, + .include_delim_str = text::tokenizer::kDefaultIncludeDelimRe}); + + env->ReleaseStringUTFChars(jsuffix_indicator, raw_suffix_indicator); + 
env->ReleaseStringUTFChars(junknown_token, raw_unknown_token); + + return reinterpret_cast<jlong>(handle.release()); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeUnloadResource( // NOLINT + JNIEnv* env, jobject thiz, jlong handle) { + delete reinterpret_cast<BertTokenizer*>(handle); + return 0; +} + +extern "C" JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeTokenize( + JNIEnv* env, jobject thiz, jlong handle, jstring jtext) { + return nativeTokenize(env, handle, jtext); +} + +extern "C" JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_BertTokenizer_nativeConvertTokensToIds( // NOLINT + JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) { + return nativeConvertTokensToIds(env, handle, jtokens); +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc b/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc new file mode 100644 index 00000000..38aff880 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.cc @@ -0,0 +1,125 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h" + +#include <iostream> + +#include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" +#include "tensorflow_lite_support/cc/utils/common_utils.h" +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + +namespace { +constexpr char kStart[] = "<START>"; +constexpr char kPad[] = "<PAD>"; +constexpr char kUnknown[] = "<UNKNOWN>"; + +void buildIndexTokenMap( + const absl::node_hash_map<std::string, int>& token_index_map, + absl::node_hash_map<int, absl::string_view>* index_token_map) { + for (const auto& token : token_index_map) { + (*index_token_map)[token.second] = token.first; + } +} + +} // namespace + +// RE2::FindAndConsume requires the delim_re_ to have a matching group in order +// to capture the matched delimiter length. Surround the regex with a +// parenthesis to create a matching group, it's fine if the regex is already +// surrounded by parenthesis. +RegexTokenizer::RegexTokenizer(const std::string& regex_pattern, + const std::string& path_to_vocab) + : delim_re_{absl::Substitute("($0)", regex_pattern)}, + token_index_map_{utils::LoadVocabAndIndexFromFile(path_to_vocab)} { + buildIndexTokenMap(token_index_map_, &index_token_map_); +} + +RegexTokenizer::RegexTokenizer(const std::string& regex_pattern, + const char* vocab_buffer_data, + size_t vocab_buffer_size) + : delim_re_{absl::Substitute("($0)", regex_pattern)}, + token_index_map_{utils::LoadVocabAndIndexFromBuffer(vocab_buffer_data, + vocab_buffer_size)} { + buildIndexTokenMap(token_index_map_, &index_token_map_); +} + +TokenizerResult RegexTokenizer::Tokenize(const std::string& input) { + absl::string_view leftover(input.data()); + absl::string_view last_end = leftover; + + TokenizerResult result; + + // Keep looking for split points until we have reached the end of the input. 
+ absl::string_view extracted_delim_token; + while (RE2::FindAndConsume(&leftover, delim_re_, &extracted_delim_token)) { + absl::string_view token(last_end.data(), + extracted_delim_token.data() - last_end.data()); + bool has_non_empty_token = token.length() > 0; + + last_end = leftover; + + // Mark the end of the previous token, only if there was something. + if (has_non_empty_token) { + result.subwords.push_back(std::string(token)); + } + } + + // Close the last token. + if (!leftover.empty()) { + result.subwords.push_back(std::string(leftover)); + } + + return result; +} + +bool RegexTokenizer::LookupId(absl::string_view key, int* result) const { + auto it = token_index_map_.find(key); + if (it == token_index_map_.end()) { + return false; + } + *result = it->second; + return true; +} + +bool RegexTokenizer::LookupWord(int vocab_id, absl::string_view* result) const { + auto it = index_token_map_.find(vocab_id); + if (it == index_token_map_.end()) { + return false; + } + *result = it->second; + return true; +} + +bool RegexTokenizer::GetStartToken(int* start_token) { + return LookupId(kStart, start_token); +} + +bool RegexTokenizer::GetPadToken(int* pad_token) { + return LookupId(kPad, pad_token); +} + +bool RegexTokenizer::GetUnknownToken(int* unknown_token) { + return LookupId(kUnknown, unknown_token); +} + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h new file mode 100644 index 00000000..c53ae496 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_ + +#include "absl/container/node_hash_map.h" +#include "re2/re2.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" + +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + +// Tokenizer to load a vocabulary and split text by regular expressions. +class RegexTokenizer : public Tokenizer { + public: + explicit RegexTokenizer(const std::string& regex_pattern, + const std::string& path_to_vocab); + + explicit RegexTokenizer(const std::string& regex_pattern, + const char* vocab_buffer_data, + size_t vocab_buffer_size); + + TokenizerResult Tokenize(const std::string& input) override; + + bool LookupId(absl::string_view key, int* result) const override; + + bool LookupWord(int vocab_id, absl::string_view* result) const override; + + bool GetStartToken(int* start_token); + bool GetPadToken(int* pad_token); + bool GetUnknownToken(int* unknown_token); + + private: + RE2 delim_re_; + absl::node_hash_map<std::string, int> token_index_map_; + absl::node_hash_map<int, absl::string_view> index_token_map_; +}; + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_REGEX_TOKENIZER_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc 
new file mode 100644 index 00000000..88065e20 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_jni.cc @@ -0,0 +1,64 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <jni.h> + +#include <cstring> +#include <string> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/str_split.h" +#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace support { + +using ::tflite::support::text::tokenizer::SentencePieceTokenizer; +using ::tflite::support::utils::GetMappedFileBuffer; + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeLoadResource( // NOLINT + JNIEnv* env, jobject obj, jobject model_buffer) { + auto model = GetMappedFileBuffer(env, model_buffer); + auto handle = + absl::make_unique<SentencePieceTokenizer>(model.data(), model.size()); + return reinterpret_cast<jlong>(handle.release()); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeUnloadResource( // NOLINT + JNIEnv* env, jobject obj, jlong handle) { + delete 
reinterpret_cast<SentencePieceTokenizer*>(handle); + return 0; +} + +extern "C" JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeTokenize( // NOLINT + JNIEnv* env, jobject thiz, jlong handle, jstring jtext) { + return nativeTokenize(env, handle, jtext); +} + +extern "C" JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_support_text_tokenizers_SentencePieceTokenizer_nativeConvertTokensToIds( // NOLINT + JNIEnv* env, jobject thiz, jlong handle, jobjectArray jtokens) { + return nativeConvertTokensToIds(env, handle, jtokens); +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h new file mode 100644 index 00000000..ed5d3da7 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ + +#include <fstream> +#include <string> +#include <vector> + +#include "src/sentencepiece_processor.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" + +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + +// SentencePiece tokenizer. Initialized with a model file. +class SentencePieceTokenizer : public Tokenizer { + public: + // Initialize the SentencePiece tokenizer from model file path. + explicit SentencePieceTokenizer(const std::string& path_to_model) { + CHECK_OK(sp_.Load(path_to_model)); + } + + explicit SentencePieceTokenizer(const char* spmodel_buffer_data, + size_t spmodel_buffer_size) { + absl::string_view buffer_binary(spmodel_buffer_data, spmodel_buffer_size); + CHECK_OK(sp_.LoadFromSerializedProto(buffer_binary)); + } + + // Perform tokenization, return tokenized results. + TokenizerResult Tokenize(const std::string& input) override { + TokenizerResult result; + std::vector<std::string>& subwords = result.subwords; + CHECK_OK(sp_.Encode(input, &subwords)); + return result; + } + + // Find the id of a string token. + bool LookupId(absl::string_view key, int* result) const override { + *result = sp_.PieceToId(key); + return true; + } + + // Find the string token of an id. 
+ bool LookupWord(int vocab_id, absl::string_view* result) const override { + *result = sp_.IdToPiece(vocab_id); + return true; + } + + private: + sentencepiece::SentencePieceProcessor sp_; +}; + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_SENTENCEPIECE_TOKENIZER_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h b/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h new file mode 100644 index 00000000..c7545064 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_H_ + +#include <fstream> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" + +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + +struct TokenizerResult { + std::vector<std::string> subwords; +}; + +// Interface of general tokenizer. +class Tokenizer { + public: + // Perform tokenization to get tokenized results. + virtual TokenizerResult Tokenize(const std::string& input) = 0; + + // Find the id of a string token. 
+ virtual bool LookupId(absl::string_view key, int* result) const = 0; + + // Find the string token from an id. + virtual bool LookupWord(int vocab_id, absl::string_view* result) const = 0; + + // Destructor. + virtual ~Tokenizer() = default; +}; + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc new file mode 100644 index 00000000..a72523be --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.cc @@ -0,0 +1,86 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h" + +namespace tflite { +namespace support { + +using ::tflite::support::text::tokenizer::Tokenizer; +using ::tflite::support::text::tokenizer::TokenizerResult; +using ::tflite::support::utils::CheckNotNull; +using ::tflite::support::utils::JStringToString; +using ::tflite::support::utils::kIllegalStateException; + +jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext) { + if (handle == 0) { + env->ThrowNew(env->FindClass(kIllegalStateException), + "Vocab not initialized!"); + return nullptr; + } + + Tokenizer* tokenizer = reinterpret_cast<Tokenizer*>(handle); + + // Get the tokenization results. + const TokenizerResult tokenize_result = + tokenizer->Tokenize(JStringToString(env, jtext)); + std::vector<std::string> subwords = tokenize_result.subwords; + + jclass string_class = CheckNotNull(env, env->FindClass("java/lang/String")); + jobjectArray result = CheckNotNull( + env, env->NewObjectArray(subwords.size(), string_class, nullptr)); + + for (int i = 0; i < subwords.size(); ++i) { + jstring text = CheckNotNull(env, env->NewStringUTF(subwords[i].data())); + if (env->ExceptionCheck()) { + return nullptr; + } + + env->SetObjectArrayElement(result, i, text); + } + + return result; +} + +jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle, + jobjectArray jtokens) { + if (handle == 0) { + env->ThrowNew(env->FindClass(kIllegalStateException), + "vocab not initialized!"); + return nullptr; + } + + Tokenizer* tokenizer = reinterpret_cast<Tokenizer*>(handle); + + // Get the token ids. 
+ const int count = env->GetArrayLength(jtokens); + jintArray result = env->NewIntArray(count); + jint* jid_ptr = env->GetIntArrayElements(result, nullptr); + + for (int i = 0; i < count; i++) { + auto jstr = + reinterpret_cast<jstring>(env->GetObjectArrayElement(jtokens, i)); + const char* token = env->GetStringUTFChars(jstr, JNI_FALSE); + int id; + tokenizer->LookupId(token, &id); + jid_ptr[i] = id; + env->ReleaseStringUTFChars(jstr, token); + } + env->ReleaseIntArrayElements(result, jid_ptr, 0); + return result; +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h new file mode 100644 index 00000000..fc7285c6 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_jni_lib.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_JNI_LIB_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_JNI_LIB_H_ + +#include <jni.h> + +#include <string> + +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace support { + +jobjectArray nativeTokenize(JNIEnv* env, jlong handle, jstring jtext); + +jintArray nativeConvertTokensToIds(JNIEnv* env, jlong handle, + jobjectArray jtokens); + +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_JNI_LIB_H_ diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc new file mode 100644 index 00000000..3e81c478 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.cc @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h" + +#include "absl/status/status.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h" +#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + + +using ::tflite::ProcessUnit; +using ::tflite::SentencePieceTokenizerOptions; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +namespace { + +StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile( + const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>* + associated_files, + const tflite::metadata::ModelMetadataExtractor* metadata_extractor) { + if (associated_files == nullptr || associated_files->size() < 1 || + associated_files->Get(0)->name() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid vocab_file from input process unit.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + ASSIGN_OR_RETURN(absl::string_view vocab_buffer, + metadata_extractor->GetAssociatedFile( + associated_files->Get(0)->name()->str())); + return vocab_buffer; +} +} // namespace + +StatusOr<std::unique_ptr<Tokenizer>> CreateTokenizerFromProcessUnit( + const tflite::ProcessUnit* tokenizer_process_unit, + const tflite::metadata::ModelMetadataExtractor* metadata_extractor) { + if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "No metadata or input process 
unit found.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + switch (tokenizer_process_unit->options_type()) { + case ProcessUnitOptions_BertTokenizerOptions: { + const tflite::BertTokenizerOptions* options = + tokenizer_process_unit->options_as<tflite::BertTokenizerOptions>(); + ASSIGN_OR_RETURN(absl::string_view vocab_buffer, + CheckAndLoadFirstAssociatedFile(options->vocab_file(), + metadata_extractor)); + return absl::make_unique<BertTokenizer>(vocab_buffer.data(), + vocab_buffer.size()); + } + case ProcessUnitOptions_SentencePieceTokenizerOptions: { + const tflite::SentencePieceTokenizerOptions* options = + tokenizer_process_unit->options_as<SentencePieceTokenizerOptions>(); + ASSIGN_OR_RETURN(absl::string_view model_buffer, + CheckAndLoadFirstAssociatedFile( + options->sentencePiece_model(), metadata_extractor)); + // TODO(b/160647204): Extract sentence piece model vocabulary + return absl::make_unique<SentencePieceTokenizer>(model_buffer.data(), + model_buffer.size()); + } + case ProcessUnitOptions_RegexTokenizerOptions: { + const tflite::RegexTokenizerOptions* options = + tokenizer_process_unit->options_as<RegexTokenizerOptions>(); + ASSIGN_OR_RETURN(absl::string_view vocab_buffer, + CheckAndLoadFirstAssociatedFile(options->vocab_file(), + metadata_extractor)); + if (options->delim_regex_pattern() == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid delim_regex_pattern from input process unit.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + + std::unique_ptr<RegexTokenizer> regex_tokenizer = + absl::make_unique<RegexTokenizer>( + options->delim_regex_pattern()->str(), vocab_buffer.data(), + vocab_buffer.size()); + + int unknown_token_id = 0; + if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "RegexTokenizer doesn't have <UNKNOWN> token.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + 
+ int pad_token_id = 0; + if (!regex_tokenizer->GetPadToken(&pad_token_id)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "RegexTokenizer doesn't have <PAD> token.", + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } + + return regex_tokenizer; + } + default: + return CreateStatusWithPayload( + absl::StatusCode::kNotFound, + absl::StrCat("Incorrect options_type:", + tokenizer_process_unit->options_type()), + TfLiteSupportStatus::kMetadataInvalidTokenizerError); + } +} + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite + diff --git a/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h new file mode 100644 index 00000000..2e50a799 --- /dev/null +++ b/tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_UTILS_H_ + +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace text { +namespace tokenizer { + + +// Create a Tokenizer from model metadata by extracting +tflite::support::StatusOr<std::unique_ptr<Tokenizer>> +CreateTokenizerFromProcessUnit( + const tflite::ProcessUnit* tokenizer_process_unit, + const tflite::metadata::ModelMetadataExtractor* metadata_extractor); + +} // namespace tokenizer +} // namespace text +} // namespace support +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_LITE_SUPPORT_CC_TEXT_TOKENIZERS_TOKENIZER_UTILS_H_ diff --git a/tensorflow_lite_support/cc/utils/BUILD b/tensorflow_lite_support/cc/utils/BUILD new file mode 100644 index 00000000..07c832f2 --- /dev/null +++ b/tensorflow_lite_support/cc/utils/BUILD @@ -0,0 +1,32 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "jni_utils", + srcs = [ + "jni_utils.cc", + ], + hdrs = [ + "jni_utils.h", + ], + deps = [ + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/java/jni", + ], +) + +cc_library( + name = "common_utils", + srcs = [ + "common_utils.cc", + ], + hdrs = [ + "common_utils.h", + ], + deps = [ + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow_lite_support/cc/utils/common_utils.cc b/tensorflow_lite_support/cc/utils/common_utils.cc new file mode 100644 index 00000000..61996f47 --- /dev/null +++ 
b/tensorflow_lite_support/cc/utils/common_utils.cc @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/utils/common_utils.h" + +#include <fstream> + +#include "absl/strings/str_split.h" + +namespace tflite { +namespace support { +namespace utils { +namespace { +struct membuf : std::streambuf { + membuf(char* begin, char* end) { this->setg(begin, begin, end); } +}; + +void ReadIStreamLineByLine( + std::istream* istream, + const std::function<void(std::string)>& line_processor) { + std::string str; + while (std::getline(*istream, str)) { + if (!str.empty()) { + line_processor(str); + } + } +} + +absl::node_hash_map<std::string, int> ReadIStreamLineSplits( + std::istream* istream) { + absl::node_hash_map<std::string, int> vocab_index_map; + std::string str; + ReadIStreamLineByLine(istream, [&vocab_index_map](const std::string& str) { + std::vector<std::string> v = absl::StrSplit(str, ' '); + vocab_index_map[v[0]] = std::stoi(v[1]); + }); + return vocab_index_map; +} + +std::vector<std::string> ReadIStreamByLine(std::istream* istream) { + std::vector<std::string> vocab_from_file; + std::string str; + + ReadIStreamLineByLine(istream, [&vocab_from_file](const std::string& str) { + vocab_from_file.push_back(str); + }); + return vocab_from_file; +} + +} // namespace + +std::vector<std::string> 
LoadVocabFromFile(const std::string& path_to_vocab) { + std::vector<std::string> vocab_from_file; + std::ifstream in(path_to_vocab.c_str()); + return ReadIStreamByLine(&in); +} + +std::vector<std::string> LoadVocabFromBuffer(const char* vocab_buffer_data, + const size_t vocab_buffer_size) { + membuf sbuf(const_cast<char*>(vocab_buffer_data), + const_cast<char*>(vocab_buffer_data + vocab_buffer_size)); + std::istream in(&sbuf); + return ReadIStreamByLine(&in); +} + +absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile( + const std::string& path_to_vocab) { + absl::node_hash_map<std::string, int> vocab_index_map; + std::ifstream in(path_to_vocab.c_str()); + return ReadIStreamLineSplits(&in); +} + +absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer( + const char* vocab_buffer_data, const size_t vocab_buffer_size) { + membuf sbuf(const_cast<char*>(vocab_buffer_data), + const_cast<char*>(vocab_buffer_data + vocab_buffer_size)); + absl::node_hash_map<std::string, int> vocab_index_map; + std::istream in(&sbuf); + return ReadIStreamLineSplits(&in); +} + +} // namespace utils +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/utils/common_utils.h b/tensorflow_lite_support/cc/utils/common_utils.h new file mode 100644 index 00000000..36232230 --- /dev/null +++ b/tensorflow_lite_support/cc/utils/common_utils.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_UTILS_COMMON_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_UTILS_COMMON_UTILS_H_ + +#include <string> +#include <vector> + +#include "absl/container/node_hash_map.h" + +namespace tflite { +namespace support { +namespace utils { + +// Read a vocab file with one vocabulary on each line, create a vector of +// strings. +std::vector<std::string> LoadVocabFromFile(const std::string& path_to_vocab); + +// read a vocab buffer with one vocab one each line, create a vector of strings +std::vector<std::string> LoadVocabFromBuffer(const char* vocab_buffer_data, + const size_t vocab_buffer_size); + +// Read a vocab file with one vocabulary and its corresponding index on each +// line separated by space, create a map of <vocab, index>. +absl::node_hash_map<std::string, int> LoadVocabAndIndexFromFile( + const std::string& path_to_vocab); + +// Read a vocab buffer with one vocabulary and its corresponding index on each +// line separated by space, create a map of <vocab, index>. +absl::node_hash_map<std::string, int> LoadVocabAndIndexFromBuffer( + const char* vocab_buffer_data, const size_t vocab_buffer_size); +} // namespace utils +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CC_UTILS_COMMON_UTILS_H_ diff --git a/tensorflow_lite_support/cc/utils/jni_utils.cc b/tensorflow_lite_support/cc/utils/jni_utils.cc new file mode 100644 index 00000000..25cf3266 --- /dev/null +++ b/tensorflow_lite_support/cc/utils/jni_utils.cc @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +#include <string.h> + +namespace tflite { +namespace support { +namespace utils { + +std::string JStringToString(JNIEnv* env, jstring jstr) { + if (jstr == nullptr) { + return std::string(); + } + const char* cstring = env->GetStringUTFChars(jstr, nullptr); + std::string result(cstring); + env->ReleaseStringUTFChars(jstr, cstring); + return result; +} + +std::vector<std::string> StringListToVector(JNIEnv* env, jobject list_object) { + jobject j_iterator = env->CallObjectMethod( + list_object, env->GetMethodID(env->GetObjectClass(list_object), + "iterator", "()Ljava/util/Iterator;")); + std::vector<std::string> result; + jmethodID has_next = + env->GetMethodID(env->GetObjectClass(j_iterator), "hasNext", "()Z"); + jmethodID get_next = env->GetMethodID(env->GetObjectClass(j_iterator), "next", + "()Ljava/lang/Object;"); + while (env->CallBooleanMethod(j_iterator, has_next)) { + jstring jstr = + static_cast<jstring>(env->CallObjectMethod(j_iterator, get_next)); + const char* raw_str = env->GetStringUTFChars(jstr, JNI_FALSE); + result.emplace_back(std::string(raw_str)); + env->ReleaseStringUTFChars(jstr, raw_str); + } + return result; +} + +absl::string_view GetMappedFileBuffer(JNIEnv* env, const jobject& file_buffer) { + return absl::string_view( + static_cast<char*>(env->GetDirectBufferAddress(file_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(file_buffer))); +} + +jbyteArray CreateByteArray(JNIEnv* env, const jbyte* 
data, int num_bytes) { + jbyteArray ret = env->NewByteArray(num_bytes); + env->SetByteArrayRegion(ret, 0, num_bytes, data); + + return ret; +} + +void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + const size_t max_msg_len = 512; + auto* message = static_cast<char*>(malloc(max_msg_len)); + if (message && (vsnprintf(message, max_msg_len, fmt, args) >= 0)) { + ThrowExceptionWithMessage(env, clazz, message); + } else { + ThrowExceptionWithMessage(env, clazz, ""); + } + if (message) { + free(message); + } + va_end(args); +} + +void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz, + const char* message) { + jclass e_class = env->FindClass(clazz); + if (strcmp(clazz, kAssertionError) == 0) { + // AssertionError cannot use ThrowNew in Java 7 + jmethodID constructor = + env->GetMethodID(e_class, "<init>", "(Ljava/lang/Object;)V"); + jstring jstr_message = env->NewStringUTF(message); + jobject e_object = env->NewObject(e_class, constructor, + static_cast<jobject>(jstr_message)); + env->Throw(static_cast<jthrowable>(e_object)); + return; + } + env->ThrowNew(e_class, message); +} + +} // namespace utils +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/cc/utils/jni_utils.h b/tensorflow_lite_support/cc/utils/jni_utils.h new file mode 100644 index 00000000..4a4aae46 --- /dev/null +++ b/tensorflow_lite_support/cc/utils/jni_utils.h @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CC_UTILS_JNI_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_CC_UTILS_JNI_UTILS_H_ + +#include <jni.h> + +#include <functional> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" + +namespace tflite { +namespace support { +namespace utils { + +const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException"; +const char kIllegalStateException[] = "java/lang/IllegalStateException"; +const char kNullPointerException[] = "java/lang/NullPointerException"; +const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException"; +const char kUnsupportedOperationException[] = + "java/lang/UnsupportedOperationException"; +const char kAssertionError[] = "java/lang/AssertionError"; + +constexpr int kInvalidPointer = 0; + +// Check if t is nullptr, throw IllegalStateException if it is. +// Used to verify different types of jobjects are correctly created from jni. +template <typename T> +T CheckNotNull(JNIEnv* env, T&& t) { + if (t == nullptr) { + env->ThrowNew(env->FindClass(kIllegalStateException), ""); + return nullptr; + } + return std::forward<T>(t); +} + +// Converts a std::vector<T> into a Java ArrayList using a converter, which +// processes a single element in the vector before adding it to the ArrayList. 
+template <typename T> +jobject ConvertVectorToArrayList(JNIEnv* env, const std::vector<T>& results, + std::function<jobject(T)> converter) { + jclass array_list_class = env->FindClass("java/util/ArrayList"); + jmethodID array_list_ctor = + env->GetMethodID(array_list_class, "<init>", "(I)V"); + jint initial_capacity = static_cast<jint>(results.size()); + jobject array_list_object = + env->NewObject(array_list_class, array_list_ctor, initial_capacity); + jmethodID array_list_add_method = + env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); + + for (const auto& ans : results) { + env->CallBooleanMethod(array_list_object, array_list_add_method, + converter(ans)); + } + return array_list_object; +} + +std::string JStringToString(JNIEnv* env, jstring jstr); + +std::vector<std::string> StringListToVector(JNIEnv* env, jobject list_object); + +// Gets a mapped file buffer from a java object representing a file. +absl::string_view GetMappedFileBuffer(JNIEnv* env, const jobject& file_buffer); + +// Creates a Java byte array object based on the input data. +jbyteArray CreateByteArray(JNIEnv* env, const jbyte* data, int num_bytes); + +void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...); + +void ThrowExceptionWithMessage(JNIEnv* env, const char* clazz, + const char* message); + +} // namespace utils +} // namespace support +} // namespace tflite +#endif // TENSORFLOW_LITE_SUPPORT_CC_UTILS_JNI_UTILS_H_ diff --git a/tensorflow_lite_support/codegen/BUILD b/tensorflow_lite_support/codegen/BUILD new file mode 100644 index 00000000..b224f987 --- /dev/null +++ b/tensorflow_lite_support/codegen/BUILD @@ -0,0 +1,86 @@ +# The tools for generating wrapper classes for a TFLite model with metadata. 
+ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "utils", + srcs = [ + "utils.cc", + ], + hdrs = [ + "utils.h", + ], + deps = [ + ], +) + +cc_library( + name = "code_generator", + srcs = [ + "code_generator.cc", + ], + hdrs = [ + "code_generator.h", + ], + deps = [ + ":utils", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + ], +) + +cc_library( + name = "metadata_helper", + srcs = [ + "metadata_helper.cc", + ], + hdrs = [ + "metadata_helper.h", + ], + deps = [ + ":utils", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_library( + name = "android_java_generator", + srcs = [ + "android_java_generator.cc", + ], + hdrs = [ + "android_java_generator.h", + ], + deps = [ + ":code_generator", + ":metadata_helper", + ":utils", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "code_generator_test", + size = "small", + srcs = ["code_generator_test.cc"], + data = ["//tensorflow_lite_support/metadata:metadata_schema.fbs"], + deps = [ + ":code_generator", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + deps = [ + ":utils", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow_lite_support/codegen/README.md b/tensorflow_lite_support/codegen/README.md new file mode 100644 index 00000000..d457edd1 --- /dev/null +++ b/tensorflow_lite_support/codegen/README.md @@ -0,0 +1,13 @@ +# TensorFlow Lite Android Wrapper Code Generator + +For TensorFlow Lite model enhanced with [metadata](https://www.tensorflow.org/lite/convert/metadata.md), +developers can use the TensorFlow Lite Android wrapper code generator to create +platform specific wrapper code. The wrapper code removes the need to interact +directly with `ByteBuffer`. 
Instead, developers can interact with the TensorFlow +Lite model with typed objects such as `Bitmap` and `Rect`. + +The usefulness of the code generator depend on the completeness of the +TensorFlow Lite model's metadata entry. Refer to the `<Codegen usage>` section +under relevant fields in +[metadata_schema.fbs](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/metadata/metadata_schema.fbs), +to see how the codegen tool parses each field. diff --git a/tensorflow_lite_support/codegen/android_java_generator.cc b/tensorflow_lite_support/codegen/android_java_generator.cc new file mode 100644 index 00000000..097119f1 --- /dev/null +++ b/tensorflow_lite_support/codegen/android_java_generator.cc @@ -0,0 +1,1017 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains the logic of android model wrapper generation. +// +// At the beginning is the helper functions handling metadata and code writer. +// +// Codes are generated in every `Generate{FOO}` functions. Gradle and Manifest +// files are simple. The wrapper file generation is a bit complex so we divided +// it into several sub-functions. 
+// +// The structure of the wrapper file looks like: +// +// [ imports ] +// [ class ] +// [ inner "Outputs" class ] +// [ innner "Metadata" class ] +// [ APIs ] ( including ctors, public APIs and private APIs ) +// +// We tried to mostly write it in a "template-generation" way. `CodeWriter` does +// the job as a template renderer. To avoid repeatedly setting the token values, +// helper functions `SetCodeWriterWith{Foo}Info` set the token values with info +// structures (`TensorInfo` and `ModelInfo`) - the Info structures are +// intermediate datastructures between Metadata (represented in Flatbuffers) and +// generated code. + +#include "tensorflow_lite_support/codegen/android_java_generator.h" + +#include <ctype.h> + +#include <algorithm> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow_lite_support/codegen/code_generator.h" +#include "tensorflow_lite_support/codegen/metadata_helper.h" +#include "tensorflow_lite_support/codegen/utils.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +namespace { + +using details_android_java::ModelInfo; +using details_android_java::TensorInfo; + +// Helper class to organize the C++ code block as a generated code block. +// Using ctor and dtor to simulate an enter/exit schema like `with` in Python. +class AsBlock { + public: + AsBlock(CodeWriter* code_writer, const std::string& before, + bool trailing_blank_line = false) + : code_writer_(code_writer), trailing_blank_line_(trailing_blank_line) { + code_writer_->AppendNoNewLine(before); + code_writer_->Append(" {"); + code_writer_->Indent(); + } + ~AsBlock() { + code_writer_->Outdent(); + code_writer_->Append("}"); + if (trailing_blank_line_) { + code_writer_->NewLine(); + } + } + + private: + CodeWriter* code_writer_; + bool trailing_blank_line_; +}; + +// Declare the functions first, so that the functions can follow a logical +// order. 
+bool GenerateWrapperClass(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperImports(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperInputs(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperOutputs(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperMetadata(CodeWriter*, const ModelInfo&, ErrorReporter*); +bool GenerateWrapperAPI(CodeWriter*, const ModelInfo&, ErrorReporter*); + +std::string GetModelVersionedName(const ModelMetadata* metadata) { + std::string model_name = "MyModel"; + if (metadata->name() != nullptr && !(metadata->name()->str().empty())) { + model_name = metadata->name()->str(); + } + std::string model_version = "unknown"; + if (metadata->version() != nullptr && !(metadata->version()->str().empty())) { + model_version = metadata->version()->str(); + } + return model_name + " (Version: " + model_version + ")"; +} + +TensorInfo CreateTensorInfo(const TensorMetadata* metadata, + const std::string& name, bool is_input, int index, + ErrorReporter* err) { + TensorInfo tensor_info; + std::string tensor_identifier = is_input ? "input" : "output"; + tensor_identifier += " " + std::to_string(index); + tensor_info.associated_axis_label_index = FindAssociatedFile( + metadata, AssociatedFileType_TENSOR_AXIS_LABELS, tensor_identifier, err); + tensor_info.associated_value_label_index = FindAssociatedFile( + metadata, AssociatedFileType_TENSOR_VALUE_LABELS, tensor_identifier, err); + if (is_input && (tensor_info.associated_axis_label_index >= 0 || + tensor_info.associated_value_label_index >= 0)) { + err->Warning( + "Found label file on input tensor (%s). Label file for input " + "tensor is not supported yet. The " + "file will be ignored.", + tensor_identifier.c_str()); + } + if (tensor_info.associated_axis_label_index >= 0 && + tensor_info.associated_value_label_index >= 0) { + err->Warning( + "Found both axis label file and value label file for tensor (%s), " + "which is not supported. 
Only the axis label file will be used.", + tensor_identifier.c_str()); + } + tensor_info.is_input = is_input; + tensor_info.name = SnakeCaseToCamelCase(name); + tensor_info.upper_camel_name = tensor_info.name; + tensor_info.upper_camel_name[0] = toupper(tensor_info.upper_camel_name[0]); + tensor_info.normalization_unit = + FindNormalizationUnit(metadata, tensor_identifier, err); + if (metadata->content() != nullptr && + metadata->content()->content_properties() != nullptr) { + // Enter tensor wrapper type inferring + if (metadata->content()->content_properties_type() == + ContentProperties_ImageProperties) { + if (metadata->content() + ->content_properties_as_ImageProperties() + ->color_space() == ColorSpaceType_RGB) { + tensor_info.content_type = "image"; + tensor_info.wrapper_type = "TensorImage"; + tensor_info.processor_type = "ImageProcessor"; + return tensor_info; + } else { + err->Warning( + "Found Non-RGB image on tensor (%s). Codegen currently does not " + "support it, and regard it as a plain numeric tensor.", + tensor_identifier.c_str()); + } + } + } + tensor_info.content_type = "tensor"; + tensor_info.wrapper_type = "TensorBuffer"; + tensor_info.processor_type = "TensorProcessor"; + return tensor_info; +} + +ModelInfo CreateModelInfo(const ModelMetadata* metadata, + const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path, + ErrorReporter* err) { + ModelInfo model_info; + if (!CodeGenerator::VerifyMetadata(metadata, err)) { + // TODO(b/150116380): Create dummy model info. 
+ err->Error("Validating metadata failed."); + return model_info; + } + model_info.package_name = package_name; + model_info.model_class_name = model_class_name; + model_info.model_asset_path = model_asset_path; + model_info.model_versioned_name = GetModelVersionedName(metadata); + const auto* graph = metadata->subgraph_metadata()->Get(0); + auto names = CodeGenerator::NameInputsAndOutputs( + graph->input_tensor_metadata(), graph->output_tensor_metadata()); + std::vector<std::string> input_tensor_names = std::move(names.first); + std::vector<std::string> output_tensor_names = std::move(names.second); + + for (int i = 0; i < input_tensor_names.size(); i++) { + model_info.inputs.push_back( + CreateTensorInfo(graph->input_tensor_metadata()->Get(i), + input_tensor_names[i], true, i, err)); + if (i < input_tensor_names.size() - 1) { + model_info.inputs_list += ", "; + model_info.input_type_param_list += ", "; + } + model_info.inputs_list += model_info.inputs[i].name; + model_info.input_type_param_list += + model_info.inputs[i].wrapper_type + " " + model_info.inputs[i].name; + } + for (int i = 0; i < output_tensor_names.size(); i++) { + model_info.outputs.push_back( + CreateTensorInfo(graph->output_tensor_metadata()->Get(i), + output_tensor_names[i], false, i, err)); + if (i < output_tensor_names.size() - 1) { + model_info.postprocessor_type_param_list += ", "; + model_info.postprocessors_list += ", "; + } + model_info.postprocessors_list += + model_info.outputs[i].name + "Postprocessor"; + model_info.postprocessor_type_param_list += + model_info.outputs[i].processor_type + " " + + model_info.outputs[i].name + "Postprocessor"; + } + return model_info; +} + +void SetCodeWriterWithTensorInfo(CodeWriter* code_writer, + const TensorInfo& tensor_info) { + code_writer->SetTokenValue("NAME", tensor_info.name); + code_writer->SetTokenValue("NAME_U", tensor_info.upper_camel_name); + code_writer->SetTokenValue("CONTENT_TYPE", tensor_info.content_type); + 
code_writer->SetTokenValue("WRAPPER_TYPE", tensor_info.wrapper_type); + std::string wrapper_name = tensor_info.wrapper_type; + wrapper_name[0] = tolower(wrapper_name[0]); + code_writer->SetTokenValue("WRAPPER_NAME", wrapper_name); + code_writer->SetTokenValue("PROCESSOR_TYPE", tensor_info.processor_type); + code_writer->SetTokenValue("NORMALIZATION_UNIT", + std::to_string(tensor_info.normalization_unit)); + code_writer->SetTokenValue( + "ASSOCIATED_AXIS_LABEL_INDEX", + std::to_string(tensor_info.associated_axis_label_index)); + code_writer->SetTokenValue( + "ASSOCIATED_VALUE_LABEL_INDEX", + std::to_string(tensor_info.associated_value_label_index)); +} + +void SetCodeWriterWithModelInfo(CodeWriter* code_writer, + const ModelInfo& model_info) { + code_writer->SetTokenValue("PACKAGE", model_info.package_name); + code_writer->SetTokenValue("MODEL_PATH", model_info.model_asset_path); + code_writer->SetTokenValue("MODEL_CLASS_NAME", model_info.model_class_name); + // Extra info, half generated. + code_writer->SetTokenValue("INPUT_TYPE_PARAM_LIST", + model_info.input_type_param_list); + code_writer->SetTokenValue("INPUTS_LIST", model_info.inputs_list); + code_writer->SetTokenValue("POSTPROCESSORS_LIST", + model_info.postprocessors_list); + code_writer->SetTokenValue("POSTPROCESSOR_TYPE_PARAM_LIST", + model_info.postprocessor_type_param_list); +} + +constexpr char JAVA_DEFAULT_PACKAGE[] = "default"; + +std::string ConvertPackageToPath(const std::string& package) { + if (package == JAVA_DEFAULT_PACKAGE) { + return ""; + } + std::string path = package; + std::replace(path.begin(), path.end(), '.', '/'); + return path; +} + +bool IsImageUsed(const ModelInfo& model) { + for (const auto& input : model.inputs) { + if (input.content_type == "image") { + return true; + } + } + for (const auto& output : model.outputs) { + if (output.content_type == "image") { + return true; + } + } + return false; +} + +// The following functions generates the wrapper Java code for a model. 
+ +bool GenerateWrapperFileContent(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append("// Generated by TFLite Support."); + code_writer->Append("package {{PACKAGE}};"); + code_writer->NewLine(); + + if (!GenerateWrapperImports(code_writer, model, err)) { + err->Error("Fail to generate imports for wrapper class."); + return false; + } + if (!GenerateWrapperClass(code_writer, model, err)) { + err->Error("Fail to generate wrapper class."); + return false; + } + code_writer->NewLine(); + return true; +} + +bool GenerateWrapperImports(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + const std::string support_pkg = "org.tensorflow.lite.support."; + std::vector<std::string> imports{ + "android.content.Context", + "java.io.IOException", + "java.nio.ByteBuffer", + "java.nio.FloatBuffer", + "java.util.Arrays", + "java.util.HashMap", + "java.util.List", + "java.util.Map", + "org.tensorflow.lite.DataType", + "org.tensorflow.lite.Tensor", + "org.tensorflow.lite.Tensor.QuantizationParams", + support_pkg + "common.FileUtil", + support_pkg + "common.TensorProcessor", + support_pkg + "common.ops.CastOp", + support_pkg + "common.ops.DequantizeOp", + support_pkg + "common.ops.NormalizeOp", + support_pkg + "common.ops.QuantizeOp", + support_pkg + "label.Category", + support_pkg + "label.TensorLabel", + support_pkg + "metadata.MetadataExtractor", + support_pkg + "metadata.schema.NormalizationOptions", + support_pkg + "model.Model", + support_pkg + "tensorbuffer.TensorBuffer", + }; + if (IsImageUsed(model)) { + for (const auto& target : + {"image.ImageProcessor", "image.TensorImage", "image.ops.ResizeOp", + "image.ops.ResizeOp.ResizeMethod"}) { + imports.push_back(support_pkg + target); + } + } + + std::sort(imports.begin(), imports.end()); + for (const auto& target : imports) { + code_writer->SetTokenValue("TARGET", target); + code_writer->Append("import {{TARGET}};"); + } + code_writer->NewLine(); + return true; +} 
+ +bool GenerateWrapperClass(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->SetTokenValue("MODEL_VERSIONED_NAME", + model.model_versioned_name); + code_writer->Append( + R"(/** Wrapper class of model {{MODEL_VERSIONED_NAME}} */)"); + const auto code_block = + AsBlock(code_writer, "public class {{MODEL_CLASS_NAME}}"); + code_writer->Append(R"(private final Metadata metadata; +private final Model model; +private static final String MODEL_NAME = "{{MODEL_PATH}}";)"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Preprocessor;"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append("private {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;"); + } + code_writer->NewLine(); + if (!GenerateWrapperOutputs(code_writer, model, err)) { + err->Error("Failed to generate output classes"); + return false; + } + code_writer->NewLine(); + if (!GenerateWrapperMetadata(code_writer, model, err)) { + err->Error("Failed to generate the metadata class"); + return false; + } + code_writer->NewLine(); + if (!GenerateWrapperAPI(code_writer, model, err)) { + err->Error("Failed to generate the common APIs"); + return false; + } + return true; +} + +bool GenerateWrapperOutputs(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append("/** Output wrapper of {@link {{MODEL_CLASS_NAME}}} */"); + auto class_block = AsBlock(code_writer, "public static class Outputs"); + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append("private final {{WRAPPER_TYPE}} {{NAME}};"); + if (tensor.associated_axis_label_index >= 0) { + code_writer->Append("private final List<String> {{NAME}}Labels;"); + } + code_writer->Append( + "private final {{PROCESSOR_TYPE}} {{NAME}}Postprocessor;"); + } + // Getters + for 
(const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->NewLine(); + if (tensor.associated_axis_label_index >= 0) { + if (tensor.content_type == "tensor") { + code_writer->Append( + R"(public List<Category> get{{NAME_U}}AsCategoryList() { + return new TensorLabel({{NAME}}Labels, postprocess{{NAME_U}}({{NAME}})).getCategoryList(); +})"); + } else { // image + err->Warning( + "Axis label for images is not supported. The labels will " + "be ignored."); + } + } else { // no label + code_writer->Append( + R"(public {{WRAPPER_TYPE}} get{{NAME_U}}As{{WRAPPER_TYPE}}() { + return postprocess{{NAME_U}}({{NAME}}); +})"); + } + } + code_writer->NewLine(); + { + const auto ctor_block = AsBlock( + code_writer, + "Outputs(Metadata metadata, {{POSTPROCESSOR_TYPE_PARAM_LIST}})"); + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + if (tensor.content_type == "image") { + code_writer->Append( + R"({{NAME}} = new TensorImage(metadata.get{{NAME_U}}Type()); +{{NAME}}.load(TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), metadata.get{{NAME_U}}Type()));)"); + } else { // FEATURE, UNKNOWN + code_writer->Append( + "{{NAME}} = " + "TensorBuffer.createFixedSize(metadata.get{{NAME_U}}Shape(), " + "metadata.get{{NAME_U}}Type());"); + } + if (tensor.associated_axis_label_index >= 0) { + code_writer->Append("{{NAME}}Labels = metadata.get{{NAME_U}}Labels();"); + } + code_writer->Append( + "this.{{NAME}}Postprocessor = {{NAME}}Postprocessor;"); + } + } + code_writer->NewLine(); + { + const auto get_buffer_block = + AsBlock(code_writer, "Map<Integer, Object> getBuffer()"); + code_writer->Append("Map<Integer, Object> outputs = new HashMap<>();"); + for (int i = 0; i < model.outputs.size(); i++) { + SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]); + code_writer->SetTokenValue("ID", std::to_string(i)); + code_writer->Append("outputs.put({{ID}}, {{NAME}}.getBuffer());"); + } + 
code_writer->Append("return outputs;"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->NewLine(); + { + auto processor_block = + AsBlock(code_writer, + "private {{WRAPPER_TYPE}} " + "postprocess{{NAME_U}}({{WRAPPER_TYPE}} {{WRAPPER_NAME}})"); + code_writer->Append( + "return {{NAME}}Postprocessor.process({{WRAPPER_NAME}});"); + } + } + return true; +} + +bool GenerateWrapperMetadata(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append( + "/** Metadata accessors of {@link {{MODEL_CLASS_NAME}}} */"); + const auto class_block = AsBlock(code_writer, "public static class Metadata"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"(private final int[] {{NAME}}Shape; +private final DataType {{NAME}}DataType; +private final QuantizationParams {{NAME}}QuantizationParams;)"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"(private final float[] {{NAME}}Mean; +private final float[] {{NAME}}Stddev;)"); + } + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"(private final int[] {{NAME}}Shape; +private final DataType {{NAME}}DataType; +private final QuantizationParams {{NAME}}QuantizationParams;)"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"(private final float[] {{NAME}}Mean; +private final float[] {{NAME}}Stddev;)"); + } + if (tensor.associated_axis_label_index >= 0 || + tensor.associated_value_label_index >= 0) { + code_writer->Append("private final List<String> {{NAME}}Labels;"); + } + } + code_writer->NewLine(); + { + const auto ctor_block = AsBlock( + code_writer, + "public Metadata(ByteBuffer buffer, Model model) throws IOException"); + code_writer->Append( + "MetadataExtractor extractor = new MetadataExtractor(buffer);"); + for (int i = 0; i < model.inputs.size(); i++) { + 
SetCodeWriterWithTensorInfo(code_writer, model.inputs[i]); + code_writer->SetTokenValue("ID", std::to_string(i)); + code_writer->Append( + R"(Tensor {{NAME}}Tensor = model.getInputTensor({{ID}}); +{{NAME}}Shape = {{NAME}}Tensor.shape(); +{{NAME}}DataType = {{NAME}}Tensor.dataType(); +{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)"); + if (model.inputs[i].normalization_unit >= 0) { + code_writer->Append( + R"(NormalizationOptions {{NAME}}NormalizationOptions = + (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions()); +FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer(); +{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()]; +{{NAME}}MeanBuffer.get({{NAME}}Mean); +FloatBuffer {{NAME}}StddevBuffer = {{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer(); +{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()]; +{{NAME}}StddevBuffer.get({{NAME}}Stddev);)"); + } + } + for (int i = 0; i < model.outputs.size(); i++) { + SetCodeWriterWithTensorInfo(code_writer, model.outputs[i]); + code_writer->SetTokenValue("ID", std::to_string(i)); + code_writer->Append( + R"(Tensor {{NAME}}Tensor = model.getOutputTensor({{ID}}); +{{NAME}}Shape = {{NAME}}Tensor.shape(); +{{NAME}}DataType = {{NAME}}Tensor.dataType(); +{{NAME}}QuantizationParams = {{NAME}}Tensor.quantizationParams();)"); + if (model.outputs[i].normalization_unit >= 0) { + code_writer->Append( + R"(NormalizationOptions {{NAME}}NormalizationOptions = + (NormalizationOptions) extractor.getInputTensorMetadata({{ID}}).processUnits({{NORMALIZATION_UNIT}}).options(new NormalizationOptions()); +FloatBuffer {{NAME}}MeanBuffer = {{NAME}}NormalizationOptions.meanAsByteBuffer().asFloatBuffer(); +{{NAME}}Mean = new float[{{NAME}}MeanBuffer.limit()]; +{{NAME}}MeanBuffer.get({{NAME}}Mean); +FloatBuffer {{NAME}}StddevBuffer = 
{{NAME}}NormalizationOptions.stdAsByteBuffer().asFloatBuffer(); +{{NAME}}Stddev = new float[{{NAME}}StddevBuffer.limit()]; +{{NAME}}StddevBuffer.get({{NAME}}Stddev);)"); + } + if (model.outputs[i].associated_axis_label_index >= 0) { + code_writer->Append(R"(String {{NAME}}LabelsFileName = + extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_AXIS_LABEL_INDEX}}).name(); +{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)"); + } else if (model.outputs[i].associated_value_label_index >= 0) { + code_writer->Append(R"(String {{NAME}}LabelsFileName = + extractor.getOutputTensorMetadata({{ID}}).associatedFiles({{ASSOCIATED_VALUE_LABEL_INDEX}}).name(); +{{NAME}}Labels = FileUtil.loadLabels(extractor.getAssociatedFile({{NAME}}LabelsFileName));)"); + } + } + } + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public int[] get{{NAME_U}}Shape() { + return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length); +} + +public DataType get{{NAME_U}}Type() { + return {{NAME}}DataType; +} + +public QuantizationParams get{{NAME_U}}QuantizationParams() { + return {{NAME}}QuantizationParams; +})"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"( +public float[] get{{NAME_U}}Mean() { + return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length); +} + +public float[] get{{NAME_U}}Stddev() { + return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length); +})"); + } + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public int[] get{{NAME_U}}Shape() { + return Arrays.copyOf({{NAME}}Shape, {{NAME}}Shape.length); +} + +public DataType get{{NAME_U}}Type() { + return {{NAME}}DataType; +} + +public QuantizationParams get{{NAME_U}}QuantizationParams() { + return {{NAME}}QuantizationParams; +})"); + if (tensor.normalization_unit >= 0) { + code_writer->Append(R"( +public float[] 
get{{NAME_U}}Mean() { + return Arrays.copyOf({{NAME}}Mean, {{NAME}}Mean.length); +} + +public float[] get{{NAME_U}}Stddev() { + return Arrays.copyOf({{NAME}}Stddev, {{NAME}}Stddev.length); +})"); + } + if (tensor.associated_axis_label_index >= 0 || + tensor.associated_value_label_index >= 0) { + code_writer->Append(R"( +public List<String> get{{NAME_U}}Labels() { + return {{NAME}}Labels; +})"); + } + } + return true; +} + +bool GenerateWrapperAPI(CodeWriter* code_writer, const ModelInfo& model, + ErrorReporter* err) { + code_writer->Append(R"(public Metadata getMetadata() { + return metadata; +} +)"); + code_writer->Append(R"(/** + * Creates interpreter and loads associated files if needed. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public static {{MODEL_CLASS_NAME}} newInstance(Context context) throws IOException { + return newInstance(context, MODEL_NAME, new Model.Options.Builder().build()); +} + +/** + * Creates interpreter and loads associated files if needed, but loading another model in the same + * input / output structure with the original one. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath) throws IOException { + return newInstance(context, modelPath, new Model.Options.Builder().build()); +} + +/** + * Creates interpreter and loads associated files if needed, with running options configured. + * + * @throws IOException if an I/O error occurs when loading the tflite model. + */ +public static {{MODEL_CLASS_NAME}} newInstance(Context context, Model.Options runningOptions) throws IOException { + return newInstance(context, MODEL_NAME, runningOptions); +} + +/** + * Creates interpreter for a user-specified model. + * + * @throws IOException if an I/O error occurs when loading the tflite model. 
+ */ +public static {{MODEL_CLASS_NAME}} newInstance(Context context, String modelPath, Model.Options runningOptions) throws IOException { + Model model = Model.createModel(context, modelPath, runningOptions); + Metadata metadata = new Metadata(model.getData(), model); + MyImageClassifier instance = new MyImageClassifier(model, metadata);)"); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append( + R"( instance.reset{{NAME_U}}Preprocessor( + instance.buildDefault{{NAME_U}}Preprocessor());)"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append( + R"( instance.reset{{NAME_U}}Postprocessor( + instance.buildDefault{{NAME_U}}Postprocessor());)"); + } + code_writer->Append(R"( return instance; +} +)"); + + // Pre, post processor setters + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public void reset{{NAME_U}}Preprocessor({{PROCESSOR_TYPE}} processor) { + {{NAME}}Preprocessor = processor; +})"); + } + for (const auto& tensor : model.outputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append(R"( +public void reset{{NAME_U}}Postprocessor({{PROCESSOR_TYPE}} processor) { + {{NAME}}Postprocessor = processor; +})"); + } + // Process method + code_writer->Append(R"( +/** Triggers the model. */ +public Outputs process({{INPUT_TYPE_PARAM_LIST}}) { + Outputs outputs = new Outputs(metadata, {{POSTPROCESSORS_LIST}}); + Object[] inputBuffers = preprocessInputs({{INPUTS_LIST}}); + model.run(inputBuffers, outputs.getBuffer()); + return outputs; +} + +/** Closes the model. 
*/ +public void close() { + model.close(); +} +)"); + { + auto block = + AsBlock(code_writer, + "private {{MODEL_CLASS_NAME}}(Model model, Metadata metadata)"); + code_writer->Append(R"(this.model = model; +this.metadata = metadata;)"); + } + for (const auto& tensor : model.inputs) { + code_writer->NewLine(); + SetCodeWriterWithTensorInfo(code_writer, tensor); + auto block = AsBlock( + code_writer, + "private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Preprocessor()"); + code_writer->Append( + "{{PROCESSOR_TYPE}}.Builder builder = new " + "{{PROCESSOR_TYPE}}.Builder()"); + if (tensor.content_type == "image") { + code_writer->Append(R"( .add(new ResizeOp( + metadata.get{{NAME_U}}Shape()[1], + metadata.get{{NAME_U}}Shape()[2], + ResizeMethod.NEAREST_NEIGHBOR)))"); + } + if (tensor.normalization_unit >= 0) { + code_writer->Append( + R"( .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))"); + } + code_writer->Append( + R"( .add(new QuantizeOp( + metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(), + metadata.get{{NAME_U}}QuantizationParams().getScale())) + .add(new CastOp(metadata.get{{NAME_U}}Type())); +return builder.build();)"); + } + for (const auto& tensor : model.outputs) { + code_writer->NewLine(); + SetCodeWriterWithTensorInfo(code_writer, tensor); + auto block = AsBlock( + code_writer, + "private {{PROCESSOR_TYPE}} buildDefault{{NAME_U}}Postprocessor()"); + code_writer->AppendNoNewLine( + R"({{PROCESSOR_TYPE}}.Builder builder = new {{PROCESSOR_TYPE}}.Builder() + .add(new DequantizeOp( + metadata.get{{NAME_U}}QuantizationParams().getZeroPoint(), + metadata.get{{NAME_U}}QuantizationParams().getScale())))"); + if (tensor.normalization_unit >= 0) { + code_writer->AppendNoNewLine(R"( + .add(new NormalizeOp(metadata.get{{NAME_U}}Mean(), metadata.get{{NAME_U}}Stddev())))"); + } + code_writer->Append(R"(; +return builder.build();)"); + } + code_writer->NewLine(); + { + const auto block = + AsBlock(code_writer, + "private Object[] 
preprocessInputs({{INPUT_TYPE_PARAM_LIST}})"); + CodeWriter param_list_gen(err); + for (const auto& tensor : model.inputs) { + SetCodeWriterWithTensorInfo(code_writer, tensor); + code_writer->Append("{{NAME}} = {{NAME}}Preprocessor.process({{NAME}});"); + SetCodeWriterWithTensorInfo(¶m_list_gen, tensor); + param_list_gen.AppendNoNewLine("{{NAME}}.getBuffer(), "); + } + param_list_gen.Backspace(2); + code_writer->AppendNoNewLine("return new Object[] {"); + code_writer->AppendNoNewLine(param_list_gen.ToString()); + code_writer->Append("};"); + } + return true; +} + +bool GenerateBuildGradleContent(CodeWriter* code_writer, + const ModelInfo& model_info) { + code_writer->Append(R"(buildscript { + repositories { + google() + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:3.2.1' + } +} + +allprojects { + repositories { + google() + jcenter() + flatDir { + dirs 'libs' + } + } +} + +apply plugin: 'com.android.library' + +android { + compileSdkVersion 29 + defaultConfig { + targetSdkVersion 29 + versionCode 1 + versionName "1.0" + } + aaptOptions { + noCompress "tflite" + } + compileOptions { + sourceCompatibility = '1.8' + targetCompatibility = '1.8' + } + lintOptions { + abortOnError false + } +} + +configurations { + libMetadata +} + +dependencies { + libMetadata 'org.tensorflow:tensorflow-lite-support:0.0.0-experimental-metadata-monolithic' +} + +task downloadLibs(type: Sync) { + from configurations.libMetadata + into "$buildDir/libs" + rename 'tensorflow-lite-support-0.0.0-experimental-metadata-monolithic.jar', "tensorflow-lite-support-metadata.jar" +} + +preBuild.dependsOn downloadLibs + +dependencies { + compileOnly 'org.checkerframework:checker-qual:2.5.8' + api 'org.tensorflow:tensorflow-lite:0.0.0-nightly' + api 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly' + api files("$buildDir/libs/tensorflow-lite-support-metadata.jar") + implementation 'org.apache.commons:commons-compress:1.19' +})"); + return true; +} + +bool 
GenerateAndroidManifestContent(CodeWriter* code_writer, + const ModelInfo& model_info) { + code_writer->Append(R"(<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="{{PACKAGE}}"> +</manifest>)"); + return true; +} + +bool GenerateDocContent(CodeWriter* code_writer, const ModelInfo& model_info) { + code_writer->Append("# {{MODEL_CLASS_NAME}} Usage"); + // TODO(b/158651848) Generate imports for TFLS util types like TensorImage. + code_writer->AppendNoNewLine(R"( +``` +import {{PACKAGE}}.{{MODEL_CLASS_NAME}}; + +// 1. Initialize the Model +{{MODEL_CLASS_NAME}} model = null; + +try { + model = {{MODEL_CLASS_NAME}}.newInstance(context); // android.content.Context +} catch (IOException e) { + e.printStackTrace(); +} + +if (model != null) { + + // 2. Set the inputs)"); + for (const auto& t : model_info.inputs) { + SetCodeWriterWithTensorInfo(code_writer, t); + if (t.content_type == "image") { + code_writer->Append(R"( + // Prepare tensor "{{NAME}}" from a Bitmap with ARGB_8888 format. + Bitmap bitmap = ...; + TensorImage {{NAME}} = TensorImage.fromBitmap(bitmap); + // Alternatively, load the input tensor "{{NAME}}" from pixel values. + // Check out TensorImage documentation to load other image data structures. + // int[] pixelValues = ...; + // int[] shape = ...; + // TensorImage {{NAME}} = new TensorImage(); + // {{NAME}}.load(pixelValues, shape);)"); + } else { + code_writer->Append(R"( + // Prepare input tensor "{{NAME}}" from an array. + // Check out TensorBuffer documentation to load other data structures. + TensorBuffer {{NAME}} = ...; + int[] values = ...; + int[] shape = ...; + {{NAME}}.load(values, shape);)"); + } + } + code_writer->Append(R"( + // 3. Run the model + {{MODEL_CLASS_NAME}}.Outputs outputs = model.process({{INPUTS_LIST}});)"); + code_writer->Append(R"( + // 4. 
Retrieve the results)"); + for (const auto& t : model_info.outputs) { + SetCodeWriterWithTensorInfo(code_writer, t); + if (t.associated_axis_label_index >= 0) { + code_writer->SetTokenValue("WRAPPER_TYPE", "List<Category>"); + code_writer->Append( + " List<Category> {{NAME}} = " + "outputs.get{{NAME_U}}AsCategoryList();"); + } else { + code_writer->Append( + " {{WRAPPER_TYPE}} {{NAME}} = " + "outputs.get{{NAME_U}}As{{WRAPPER_TYPE}}();"); + } + } + code_writer->Append(R"(} +```)"); + return true; +} + +GenerationResult::File GenerateWrapperFile(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + const auto java_path = JoinPath(module_root, "src/main/java"); + const auto package_path = + JoinPath(java_path, ConvertPackageToPath(model_info.package_name)); + const auto file_path = + JoinPath(package_path, model_info.model_class_name + JAVA_EXT); + + CodeWriter code_writer(err); + code_writer.SetIndentString(" "); + SetCodeWriterWithModelInfo(&code_writer, model_info); + + if (!GenerateWrapperFileContent(&code_writer, model_info, err)) { + err->Error("Generating Java wrapper content failed."); + } + + const auto java_file = code_writer.ToString(); + return GenerationResult::File{file_path, java_file}; +} + +GenerationResult::File GenerateBuildGradle(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + const auto file_path = JoinPath(module_root, "build.gradle"); + CodeWriter code_writer(err); + SetCodeWriterWithModelInfo(&code_writer, model_info); + if (!GenerateBuildGradleContent(&code_writer, model_info)) { + err->Error("Generating build.gradle failed."); + } + const auto content = code_writer.ToString(); + return GenerationResult::File{file_path, content}; +} + +GenerationResult::File GenerateAndroidManifest(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + const auto file_path = JoinPath(module_root, "src/main/AndroidManifest.xml"); + CodeWriter 
code_writer(err); + SetCodeWriterWithModelInfo(&code_writer, model_info); + if (!GenerateAndroidManifestContent(&code_writer, model_info)) { + err->Error("Generating AndroidManifest.xml failed."); + } + return GenerationResult::File{file_path, code_writer.ToString()}; +} + +GenerationResult::File GenerateDoc(const std::string& module_root, + const ModelInfo& model_info, + ErrorReporter* err) { + std::string lower = model_info.model_class_name; + for (int i = 0; i < lower.length(); i++) { + lower[i] = std::tolower(lower[i]); + } + const auto file_path = JoinPath(module_root, lower + ".md"); + CodeWriter code_writer(err); + SetCodeWriterWithModelInfo(&code_writer, model_info); + if (!GenerateDocContent(&code_writer, model_info)) { + err->Error("Generating doc failed."); + } + return GenerationResult::File{file_path, code_writer.ToString()}; +} + +} // namespace + +AndroidJavaGenerator::AndroidJavaGenerator(const std::string& module_root) + : CodeGenerator(), module_root_(module_root) {} + +GenerationResult AndroidJavaGenerator::Generate( + const Model* model, const std::string& package_name, + const std::string& model_class_name, const std::string& model_asset_path) { + GenerationResult result; + if (model == nullptr) { + err_.Error( + "Cannot read model from the buffer. Codegen will generate nothing."); + return result; + } + const ModelMetadata* metadata = GetMetadataFromModel(model); + if (metadata == nullptr) { + err_.Error( + "Cannot find TFLite Metadata in the model. 
Codegen will generate " + "nothing."); + return result; + } + details_android_java::ModelInfo model_info = CreateModelInfo( + metadata, package_name, model_class_name, model_asset_path, &err_); + result.files.push_back(GenerateWrapperFile(module_root_, model_info, &err_)); + result.files.push_back(GenerateBuildGradle(module_root_, model_info, &err_)); + result.files.push_back( + GenerateAndroidManifest(module_root_, model_info, &err_)); + result.files.push_back(GenerateDoc(module_root_, model_info, &err_)); + return result; +} + +GenerationResult AndroidJavaGenerator::Generate( + const char* model_storage, const std::string& package_name, + const std::string& model_class_name, const std::string& model_asset_path) { + const Model* model = GetModel(model_storage); + return Generate(model, package_name, model_class_name, model_asset_path); +} + +std::string AndroidJavaGenerator::GetErrorMessage() { + return err_.GetMessage(); +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/codegen/android_java_generator.h b/tensorflow_lite_support/codegen/android_java_generator.h new file mode 100644 index 00000000..634ccf69 --- /dev/null +++ b/tensorflow_lite_support/codegen/android_java_generator.h @@ -0,0 +1,116 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ +#define TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow_lite_support/codegen/code_generator.h" +#include "tensorflow_lite_support/codegen/utils.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +namespace details_android_java { + +/// The intermediate data structure for generating code from TensorMetadata. +/// Should only be used as const reference when created. +struct TensorInfo { + std::string name; + std::string upper_camel_name; + std::string content_type; + std::string wrapper_type; + std::string processor_type; + bool is_input; + /// Optional. Set to -1 if not applicable. + int normalization_unit; + /// Optional. Set to -1 if associated_axis_label is empty. + int associated_axis_label_index; + /// Optional. Set to -1 if associated_value_label is empty. + int associated_value_label_index; +}; + +/// The intermediate data structure for generating code from ModelMetadata. +/// Should only be used as const reference when created. +struct ModelInfo { + std::string package_name; + std::string model_asset_path; + std::string model_class_name; + std::string model_versioned_name; + std::vector<TensorInfo> inputs; + std::vector<TensorInfo> outputs; + // Extra helper fields. For models with inputs "a", "b" and outputs "x", "y": + std::string input_type_param_list; + // e.g. "TensorImage a, TensorBuffer b" + std::string inputs_list; + // e.g. "a, b" + std::string postprocessor_type_param_list; + // e.g. "ImageProcessor xPostprocessor, TensorProcessor yPostprocessor" + std::string postprocessors_list; + // e.g. 
"xPostprocessor, yPostprocessor" +}; + +} // namespace details_android_java + +constexpr char JAVA_EXT[] = ".java"; + +/// Generates Android supporting codes and modules (in Java) based on TFLite +/// metadata. +class AndroidJavaGenerator : public CodeGenerator { + public: + /// Creates an AndroidJavaGenerator. + /// Args: + /// - module_root: The root of destination Java module. + explicit AndroidJavaGenerator(const std::string& module_root); + + /// Generates files. Returns the file paths and contents. + /// Args: + /// - model: The TFLite model with Metadata filled. + /// - package_name: The name of the Java package which generated classes + /// belong to. + /// - model_class_name: A readable name of the generated wrapper class, such + /// as "ImageClassifier", "MobileNetV2" or "MyModel". + /// - model_asset_path: The relevant path to the model file in the asset. + // TODO(b/141225157): Automatically generate model_class_name. + GenerationResult Generate(const Model* model, const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path); + + /// Generates files and returns the file paths and contents. + /// It's mostly identical with the previous one, but the model here is + /// provided as binary flatbuffer content without parsing. 
+ GenerationResult Generate(const char* model_storage, + const std::string& package_name, + const std::string& model_class_name, + const std::string& model_asset_path); + + std::string GetErrorMessage(); + + private: + const std::string module_root_; + ErrorReporter err_; +}; + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_ diff --git a/tensorflow_lite_support/codegen/code_generator.cc b/tensorflow_lite_support/codegen/code_generator.cc new file mode 100644 index 00000000..1337708d --- /dev/null +++ b/tensorflow_lite_support/codegen/code_generator.cc @@ -0,0 +1,179 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/codegen/code_generator.h" + +#include <cctype> +#include <unordered_map> +#include <unordered_set> +#include <utility> + +#include "tensorflow_lite_support/codegen/utils.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +namespace { + +void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) { + auto& names = *names_ptr; + std::unordered_map<std::string, int> indexes; + std::unordered_map<std::string, int> first_appearance; + for (int i = 0; i < names.size(); i++) { + if (indexes.find(names[i]) == indexes.end()) { + indexes[names[i]] = 1; + first_appearance[names[i]] = i; + } else { + indexes[names[i]] += 1; + names[i].append(std::to_string(indexes[names[i]])); + } + } + for (const auto& it : first_appearance) { + const auto& name = it.first; + const auto i = it.second; + if (indexes[name] > 1) { + names[i].append("1"); + } + } +} + +} // namespace + +CodeGenerator::CodeGenerator() {} + +bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata, + ErrorReporter* err) { + if (metadata == nullptr) { + err->Error("Loading nullptr is not allowed"); + return false; + } + if (metadata->subgraph_metadata()->size() != 1) { + err->Error("Only exact 1 subgraph is supported"); + return false; + } + return true; +} + +std::pair<std::vector<std::string>, std::vector<std::string>> +CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs, + const TensorMetadataList* outputs) { + std::vector<std::string> input_names; + std::vector<std::string> output_names; + if (inputs != nullptr) { + input_names.reserve(inputs->size()); + for (const auto* tensor : *inputs) { + input_names.push_back(NameTensor(*tensor, "input")); + } + } + if (outputs != nullptr) { + output_names.reserve(outputs->size()); + for (const auto* tensor : *outputs) { + 
output_names.push_back(NameTensor(*tensor, "output")); + } + } + // Solve conflict + ResolveConflictedInputAndOutputNames(&input_names, &output_names); + return std::make_pair(input_names, output_names); +} + +std::string CodeGenerator::ConvertToValidName(const std::string& name) { + // lowercase all + std::string result = name; + for (int i = 0; i < result.size(); i++) { + result[i] = std::tolower(result[i]); + } + // replace all non-alpha or non-numeric with underscores, except underscore + // itself + for (int i = 0; i < result.size(); i++) { + if (result[i] != '_' && !std::isalnum(result[i])) { + result[i] = '_'; + } + } + // remove leading underscores + int leading_underscores = 0; + while (leading_underscores < result.size() && + result[leading_underscores] == '_') { + leading_underscores++; + } + result.erase(0, leading_underscores); + if (result.empty()) { + return ""; + } + // first char should be alpha + if (std::isalpha(result[0])) { + return result; + } + return "tensor_" + result; +} + +std::string CodeGenerator::NameTensor(const TensorMetadata& tensor, + const std::string& default_name) { + if (tensor.name() != nullptr && tensor.name()->size() > 0) { + // TODO(b/141225157) Validate tensor name. It should be in lower case. 
+ auto suggested_name = ConvertToValidName(tensor.name()->str()); + if (!suggested_name.empty()) { + return suggested_name; + } + } + auto* content = tensor.content(); + if (content == nullptr || content->content_properties() == nullptr) { + return default_name; + } + switch (content->content_properties_type()) { + case ContentProperties_ImageProperties: + return "image"; + case ContentProperties_FeatureProperties: + return "feature"; + default: + return default_name; + } +} + +void CodeGenerator::ResolveConflictedInputAndOutputNames( + std::vector<std::string>* inputs, std::vector<std::string>* outputs) { + std::unordered_set<std::string> io_conflict; + auto& input_names = *inputs; + auto& output_names = *outputs; + for (const auto& input : input_names) { + if (io_conflict.find(input) != io_conflict.end()) { + continue; + } + for (const auto& output : output_names) { + if (input == output) { + io_conflict.insert(input); + break; + } + } + } + for (int i = 0; i < input_names.size(); i++) { + if (io_conflict.find(input_names[i]) != io_conflict.end()) { + input_names[i] = "input_" + input_names[i]; + } + } + for (int i = 0; i < output_names.size(); i++) { + if (io_conflict.find(output_names[i]) != io_conflict.end()) { + output_names[i] = "output_" + output_names[i]; + } + } + // 2. Second, add index if input[i] == input[j] + ResolveConflictedNamesByAddingIndex(&input_names); + ResolveConflictedNamesByAddingIndex(&output_names); +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/codegen/code_generator.h b/tensorflow_lite_support/codegen/code_generator.h new file mode 100644 index 00000000..b557773d --- /dev/null +++ b/tensorflow_lite_support/codegen/code_generator.h @@ -0,0 +1,80 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_CODE_GENERATOR_H_ +#define TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_CODE_GENERATOR_H_ + +#include <map> +#include <memory> +#include <sstream> +#include <string> + +#include "tensorflow_lite_support/codegen/utils.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +struct GenerationResult { + struct File { + std::string path; + std::string content; + }; + std::vector<File> files; +}; + +/// Defines language-independent codegen strategies, like class naming, .etc. +/// Should not be used directly. +class CodeGenerator { + public: + CodeGenerator(); + + using TensorMetadataList = + typename flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>; + + virtual ~CodeGenerator() {} + + // Strategies. + /// Names all the IO tensors. It's useful when they don't have names, or the + /// names have conflicts. We have to name every tensor for code generation. + // TODO(b/141225157): Add reserved keywords check. + static std::pair<std::vector<std::string>, std::vector<std::string>> + NameInputsAndOutputs(const TensorMetadataList* inputs, + const TensorMetadataList* outputs); + + /// Loads a metadata for code generation. + /// Returns false if the metadata is not good for generation. + static bool VerifyMetadata(const ModelMetadata* metadata, ErrorReporter* err); + + protected: + /// Converts a name into a valid form. Rules: + /// - lower all letters. 
+ /// - replace all non alphabet nor numeric characters with underscores. + /// - remove prefix underscores. + /// - add prefix if the leading character is a number. + /// Returns empty string if not possible. + static std::string ConvertToValidName(const std::string& name); + static std::string NameTensor(const TensorMetadata& tensor, + const std::string& default_name); + static void ResolveConflictedInputAndOutputNames( + std::vector<std::string>* input, std::vector<std::string>* output); +}; + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_CODE_GENERATOR_H_ diff --git a/tensorflow_lite_support/codegen/code_generator_test.cc b/tensorflow_lite_support/codegen/code_generator_test.cc new file mode 100644 index 00000000..5e9d64a0 --- /dev/null +++ b/tensorflow_lite_support/codegen/code_generator_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/codegen/code_generator.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tflite { +namespace support { +namespace codegen { +namespace { + +using ::testing::ElementsAreArray; + +class CodeGeneratorTest : public ::testing::Test { + public: + class TestingCodeGenerator : public CodeGenerator { + public: + explicit TestingCodeGenerator() : CodeGenerator() {} + + // Make tested method public. + static std::string ConvertToValidName(const std::string& name) { + return CodeGenerator::ConvertToValidName(name); + } + static void ResolveConflictedInputAndOutputNames( + std::vector<std::string>* input, std::vector<std::string>* output) { + CodeGenerator::ResolveConflictedInputAndOutputNames(input, output); + } + }; +}; + +TEST_F(CodeGeneratorTest, UpperCasesShouldLower) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("AlphaBetCOOL"), + "alphabetcool"); +} + +TEST_F(CodeGeneratorTest, NonAlphaNumShouldReplace) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("A+=B C\t"), "a__b_c_"); +} + +TEST_F(CodeGeneratorTest, NoLeadingUnderscore) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("+KAI Z"), "kai_z"); +} + +TEST_F(CodeGeneratorTest, NoLeadingNumbers) { + EXPECT_THAT(TestingCodeGenerator::ConvertToValidName("3000 Cool Tensors"), + "tensor_3000_cool_tensors"); +} + +TEST_F(CodeGeneratorTest, TestSimpleIONames) { + std::vector<std::string> inputs = {"image"}; + std::vector<std::string> outputs = {"output"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"image"})); + EXPECT_THAT(outputs, ElementsAreArray({"output"})); +} + +TEST_F(CodeGeneratorTest, TestIOConflict) { + std::vector<std::string> inputs = {"image"}; + std::vector<std::string> outputs = {"image"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + 
EXPECT_THAT(inputs, ElementsAreArray({"input_image"})); + EXPECT_THAT(outputs, ElementsAreArray({"output_image"})); +} + +TEST_F(CodeGeneratorTest, TestInternalConflict) { + std::vector<std::string> inputs = {"image", "image"}; + std::vector<std::string> outputs = {"output"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"image1", "image2"})); + EXPECT_THAT(outputs, ElementsAreArray({"output"})); +} + +TEST_F(CodeGeneratorTest, TestAllConflictNTo1) { + std::vector<std::string> inputs = {"image", "image"}; + std::vector<std::string> outputs = {"image"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, ElementsAreArray({"input_image1", "input_image2"})); + EXPECT_THAT(outputs, ElementsAreArray({"output_image"})); +} + +TEST_F(CodeGeneratorTest, TestAllConflict) { + std::vector<std::string> inputs = {"image", "audio", "image", "audio", + "audio"}; + std::vector<std::string> outputs = {"image", "image", "audio", "feature", + "feature"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, + ElementsAreArray({"input_image1", "input_audio1", "input_image2", + "input_audio2", "input_audio3"})); + EXPECT_THAT(outputs, + ElementsAreArray({"output_image1", "output_image2", + "output_audio", "feature1", "feature2"})); +} + +TEST_F(CodeGeneratorTest, TestAllConflictReversed) { + std::vector<std::string> inputs = {"image", "image", "audio", "feature", + "feature"}; + std::vector<std::string> outputs = {"image", "audio", "image", "audio", + "audio"}; + TestingCodeGenerator::ResolveConflictedInputAndOutputNames(&inputs, &outputs); + EXPECT_THAT(inputs, + ElementsAreArray({"input_image1", "input_image2", "input_audio", + "feature1", "feature2"})); + EXPECT_THAT(outputs, ElementsAreArray({"output_image1", "output_audio1", + "output_image2", "output_audio2", + "output_audio3"})); +} + +} // namespace +} // 
namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/codegen/metadata_helper.cc b/tensorflow_lite_support/codegen/metadata_helper.cc new file mode 100644 index 00000000..00c97236 --- /dev/null +++ b/tensorflow_lite_support/codegen/metadata_helper.cc @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/codegen/metadata_helper.h" + +#include "tensorflow_lite_support/codegen/utils.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +constexpr char BUFFER_KEY[] = "TFLITE_METADATA"; +const ModelMetadata* GetMetadataFromModel(const Model* model) { + if (model == nullptr || model->metadata() == nullptr) { + return nullptr; + } + for (auto i = 0; i < model->metadata()->size(); i++) { + const auto* name = model->metadata()->Get(i)->name(); + if (name != nullptr && name->str() == BUFFER_KEY) { + const auto buffer_index = model->metadata()->Get(i)->buffer(); + if (model->buffers() == nullptr || + model->buffers()->size() <= buffer_index) { + continue; + } + const auto* buffer_vec = model->buffers()->Get(buffer_index)->data(); + if (buffer_vec == nullptr || buffer_vec->data() == nullptr) { + continue; + } + return GetModelMetadata(buffer_vec->data()); + } + } + return 
nullptr; +} + +int FindAssociatedFile(const TensorMetadata* metadata, + const AssociatedFileType file_type, + const std::string& tensor_identifier, + ErrorReporter* err) { + int result = -1; + if (metadata->associated_files() == nullptr || + metadata->associated_files()->size() == 0) { + return result; + } + for (int i = 0; i < metadata->associated_files()->size(); i++) { + const auto* file_metadata = metadata->associated_files()->Get(i); + if (file_metadata->type() == file_type) { + if (result >= 0) { + err->Warning( + "Multiple associated file of type %d found on tensor %s. Only the " + "first one will be used.", + file_type, tensor_identifier.c_str()); + continue; + } + result = i; + } + } + return result; +} + +int FindNormalizationUnit(const TensorMetadata* metadata, + const std::string& tensor_identifier, + ErrorReporter* err) { + int result = -1; + if (metadata->process_units() == nullptr || + metadata->process_units()->size() == 0) { + return result; + } + for (int i = 0; i < metadata->process_units()->size(); i++) { + const auto* process_uint = metadata->process_units()->Get(i); + if (process_uint->options_type() == + ProcessUnitOptions_NormalizationOptions) { + if (result >= 0) { + err->Warning( + "Multiple normalization unit found in tensor %s. Only the first " + "one will be effective.", + tensor_identifier.c_str()); + continue; + } + result = i; + } + } + return result; +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/codegen/metadata_helper.h b/tensorflow_lite_support/codegen/metadata_helper.h new file mode 100644 index 00000000..8e3dc6ab --- /dev/null +++ b/tensorflow_lite_support/codegen/metadata_helper.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_METADATA_HELPER_H_ +#define TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_METADATA_HELPER_H_ + +#include <string> + +#include "tensorflow_lite_support/codegen/utils.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace support { +namespace codegen { + +/// Parses a ModelMetadata out from a Model. The returned ModelMetadata's +/// lifetime is scoped by the model. +/// Returns nullptr if we cannot find any metadata. +const ModelMetadata* GetMetadataFromModel(const Model* model); + +/// Finds an associated file from a TensorMetadata of certain type. If there're +/// multiple files meet the criteria, only the first one is used. If there's no +/// file meets the criteria, -1 will be returned. +int FindAssociatedFile(const TensorMetadata* metadata, + const AssociatedFileType file_type, + const std::string& tensor_identifier, + ErrorReporter* err); + +/// Find the first normalization unit. If none, return -1. 
+int FindNormalizationUnit(const TensorMetadata* metadata, + const std::string& tensor_identifier, + ErrorReporter* err); + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_METADATA_HELPER_H_ diff --git a/tensorflow_lite_support/codegen/python/BUILD b/tensorflow_lite_support/codegen/python/BUILD new file mode 100644 index 00000000..ee4dcbbd --- /dev/null +++ b/tensorflow_lite_support/codegen/python/BUILD @@ -0,0 +1,37 @@ +load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "_pywrap_codegen", + srcs = [ + "codegen_lib.cc", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_codegen", + deps = [ + "//tensorflow_lite_support/codegen:android_java_generator", + "//tensorflow_lite_support/codegen:code_generator", + "@local_config_python//:python_headers", + "@pybind11", + ], +) + +py_binary( + name = "codegen", + srcs = [ + "codegen.py", + ], + python_version = "PY3", + deps = [ + ":_pywrap_codegen", + "@absl_py//absl:app", + "@absl_py//absl/flags", + "@absl_py//absl/logging", + ], +) diff --git a/tensorflow_lite_support/codegen/python/codegen.py b/tensorflow_lite_support/codegen/python/codegen.py new file mode 100644 index 00000000..7309a69d --- /dev/null +++ b/tensorflow_lite_support/codegen/python/codegen.py @@ -0,0 +1,104 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generates Android Java sources from a TFLite model with metadata.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import sys +from absl import app +from absl import flags +from absl import logging + +from tensorflow_lite_support.codegen.python import _pywrap_codegen + +FLAGS = flags.FLAGS + +flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.') +flags.DEFINE_string('destination', None, 'Path of destination of generation.') +flags.DEFINE_string('package_name', 'org.tensorflow.lite.support', + 'Name of generated java package to put the wrapper class.') +flags.DEFINE_string( + 'model_class_name', 'MyModel', + 'Name of generated wrapper class (should not contain package name).') +flags.DEFINE_string( + 'model_asset_path', '', + '(Optional) Path to the model in generated assets/ dir. If not set, ' + 'generator will use base name of input model.' 
+) + + +def get_model_buffer(path): + if not os.path.isfile(path): + logging.error('Cannot find model at path %s.', path) + with open(path, 'rb') as f: + buf = f.read() + return buf + + +def prepare_directory_for_file(file_path): + target_dir = os.path.dirname(file_path) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + return + if not os.path.isdir(target_dir): + logging.error('Cannot write to %s', target_dir) + + +def run_main(argv): + """Main function of the codegen.""" + + if len(argv) > 1: + logging.error('None flag arguments found: [%s]', ', '.join(argv[1:])) + + codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination) + model_buffer = get_model_buffer(FLAGS.model) + model_asset_path = FLAGS.model_asset_path + if not model_asset_path: + model_asset_path = os.path.basename(FLAGS.model) + result = codegen.generate(model_buffer, FLAGS.package_name, + FLAGS.model_class_name, model_asset_path) + error_message = codegen.get_error_message().strip() + if error_message: + logging.error(error_message) + if not result.files: + logging.error('Generation failed!') + return + + for each in result.files: + prepare_directory_for_file(each.path) + with open(each.path, 'w') as f: + f.write(each.content) + + logging.info('Generation succeeded!') + model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets', + model_asset_path) + prepare_directory_for_file(model_asset_path) + shutil.copy(FLAGS.model, model_asset_path) + logging.info('Model copied into assets!') + + +# Simple wrapper to make the code pip-friendly +def main(): + flags.mark_flag_as_required('model') + flags.mark_flag_as_required('destination') + app.run(main=run_main, argv=sys.argv) + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow_lite_support/codegen/python/codegen_lib.cc b/tensorflow_lite_support/codegen/python/codegen_lib.cc new file mode 100644 index 00000000..6b2cd5ea --- /dev/null +++ b/tensorflow_lite_support/codegen/python/codegen_lib.cc @@ -0,0 +1,49 
@@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/detail/common.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" +#include "tensorflow_lite_support/codegen/android_java_generator.h" +#include "tensorflow_lite_support/codegen/code_generator.h" + +namespace tflite { +namespace support { +namespace codegen { + +template <typename... 
Args> +using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>; + +PYBIND11_MODULE(_pywrap_codegen, m) { + pybind11::class_<AndroidJavaGenerator>(m, "AndroidJavaGenerator") + .def(pybind11::init<const std::string &>()) + .def("generate", + overload_cast_<const char *, const std::string &, + const std::string &, const std::string &>()( + &AndroidJavaGenerator::Generate)) + .def("get_error_message", &AndroidJavaGenerator::GetErrorMessage); + pybind11::class_<GenerationResult>(m, "GenerationResult") + .def(pybind11::init<>()) + .def_readwrite("files", &GenerationResult::files); + pybind11::class_<GenerationResult::File>(m, "GenerationResultFile") + .def(pybind11::init<>()) + .def_readwrite("path", &GenerationResult::File::path) + .def_readwrite("content", &GenerationResult::File::content); +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/codegen/utils.cc b/tensorflow_lite_support/codegen/utils.cc new file mode 100644 index 00000000..c75fc5fa --- /dev/null +++ b/tensorflow_lite_support/codegen/utils.cc @@ -0,0 +1,194 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/codegen/utils.h" + +#include <cstdarg> + +namespace tflite { +namespace support { +namespace codegen { + +int ErrorReporter::Warning(const char* format, ...) 
{ + va_list args; + va_start(args, format); + return Report("[WARN] ", format, args); +} + +int ErrorReporter::Error(const char* format, ...) { + va_list args; + va_start(args, format); + return Report("[ERROR] ", format, args); +} + +int ErrorReporter::Report(const char* prefix, const char* format, + va_list args) { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << prefix << buf << std::endl; + return formatted; +} + +std::string ErrorReporter::GetMessage() { + std::string value = buffer_.str(); + buffer_.str(""); + return value; +} + +CodeWriter::CodeWriter(ErrorReporter* err) : indent_(0), err_(err) {} + +void CodeWriter::SetTokenValue(const std::string& token, + const std::string& value) { + value_map_[token] = value; +} + +const std::string CodeWriter::GetTokenValue(const std::string& token) const { + auto iter = value_map_.find(token); + if (iter == value_map_.end()) { + // Typically only Code Generator's call this function (or `Append`). It's + // their duty to make sure the token is valid, and requesting for an invalid + // token implicits flaws in the code generation logic. 
+ err_->Error("Internal: Cannot find value with token '%s'", token.c_str()); + return ""; + } + return iter->second; +} + +void CodeWriter::SetIndentString(const std::string& indent_str) { + indent_str_ = indent_str; +} + +void CodeWriter::Indent() { indent_++; } + +void CodeWriter::Outdent() { indent_--; } + +std::string CodeWriter::GenerateIndent() const { + std::string res; + res.reserve(indent_str_.size() * indent_); + for (int i = 0; i < indent_; i++) { + res.append(indent_str_); + } + return res; +} + +void CodeWriter::Append(const std::string& text) { AppendInternal(text, true); } + +void CodeWriter::AppendNoNewLine(const std::string& text) { + AppendInternal(text, false); +} + +void CodeWriter::AppendInternal(const std::string& text, bool newline) { + // Prefix indent + if ((buffer_.empty() // nothing in the buffer + || buffer_.back() == '\n') // is on new line + // is writing on current line + && (!text.empty() && text[0] != '\n' && text[0] != '\r')) { + buffer_.append(GenerateIndent()); + } + // State machine variables + bool in_token = false; + int i = 0; + // Rough memory reserve + buffer_.reserve(buffer_.size() + text.size()); + std::string token_buffer; + // A simple LL1 analysis + while (i < text.size()) { + char cur = text[i]; + char cur_next = i == text.size() - 1 ? 
'\0' : text[i + 1]; // Set guardian + if (!in_token) { + if (cur == '{' && cur_next == '{') { // Enter token + in_token = true; + i += 2; + } else if (cur == '\n') { // We need to apply global indent here + buffer_.push_back(cur); + if (cur_next != '\0' && cur_next != '\n' && cur_next != '\r') { + buffer_.append(GenerateIndent()); + } + i += 1; + } else { + buffer_.push_back(cur); + i += 1; + } + } else { + if (cur == '}' && cur_next == '}') { // Close token + in_token = false; + const auto value = GetTokenValue(token_buffer); + buffer_.append(value); + token_buffer.clear(); + i += 2; + } else { + token_buffer.push_back(cur); + i += 1; + } + } + } + if (!token_buffer.empty()) { + // Typically only Code Generator's call this function. It's + // their duty to make sure the code (or template) has valid syntax, and + // unclosed "{{...}}" implicits severe error in the template. + err_->Error("Internal: Invalid template: {{token}} is not closed."); + } + if (newline) { + buffer_.push_back('\n'); + } +} + +void CodeWriter::NewLine() { Append(""); } + +void CodeWriter::Backspace(int n) { + buffer_.resize(buffer_.size() > n ? buffer_.size() - n : 0); +} + +std::string CodeWriter::ToString() const { return buffer_; } + +bool CodeWriter::IsStreamEmpty() const { return buffer_.empty(); } + +void CodeWriter::Clear() { + buffer_.clear(); + value_map_.clear(); + indent_ = 0; +} + +std::string SnakeCaseToCamelCase(const std::string& s) { + std::string t; + t.reserve(s.length()); + size_t i = 0; + // Note: Use simple string += for simplicity. 
+ bool cap = false; + while (i < s.size()) { + const char c = s[i++]; + if (c == '_') { + cap = true; + } else if (cap) { + t += toupper(c); + cap = false; + } else { + t += c; + } + } + return t; +} + +std::string JoinPath(const std::string& a, const std::string& b) { + if (a.empty()) return b; + std::string a_fixed = a; + if (!a_fixed.empty() && a_fixed.back() == '/') a_fixed.pop_back(); + std::string b_fixed = b; + if (!b_fixed.empty() && b_fixed.front() == '/') b_fixed.erase(0, 1); + return a_fixed + "/" + b_fixed; +} + +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/codegen/utils.h b/tensorflow_lite_support/codegen/utils.h new file mode 100644 index 00000000..98768b6a --- /dev/null +++ b/tensorflow_lite_support/codegen/utils.h @@ -0,0 +1,127 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_UTILS_H_ +#define TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_UTILS_H_ + +#include <map> +#include <sstream> +#include <string> + +namespace tflite { +namespace support { +namespace codegen { + +/// Collects runtime error logs which could be showed later. +// TODO(b/150538286): Consider a better mechanism to simplify callsite code. 
+class ErrorReporter { + public: + int Warning(const char* format, ...); + int Error(const char* format, ...); + std::string GetMessage(); + + private: + int Report(const char* prefix, const char* format, va_list args); + std::stringstream buffer_; +}; + +/// Implements basic code generating with text templates. +/// +/// It could accept code templates and concatenate them into complete codes. A +/// template could contain named values. +/// +/// Example code: +/// CodeWriter code; +/// code.SetValue("NAME", "Foo"); +/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }"); +/// code.SetValue("NAME", "Bar"); +/// code.Append("void {{NAME}}() { printf("%s", "{{NAME}}"); }"); +/// +/// Output: +/// void Foo() { printf("%s", "Foo"); } +/// void Bar() { printf("%s", "Bar"); } +class CodeWriter { + public: + explicit CodeWriter(ErrorReporter* err); + /// Sets value to a token. When generating code with template, a string in a + /// pair of {{ and }} will be regarded as a token and replaced with the + /// corresponding value in code generation. + /// It rewrites if the token already has a value. + void SetTokenValue(const std::string& token, const std::string& value); + + /// Gets the current value set on the given token. + const std::string GetTokenValue(const std::string& token) const; + + /// Sets the unit indent string. For example, in Java it should be " ". + void SetIndentString(const std::string& indent); + + /// Increases the indent by a unit (the string set in SetIndentString). + void Indent(); + + /// Decreases the indent by a unit (the string set in SetIndentString). + void Outdent(); + + /// Generates the indentation string. + std::string GenerateIndent() const; + + /// Appends a piece of template codes to the stream. Every named value will be + /// replaced via the real value. A new line will always be appended at the + /// end. + void Append(const std::string& text); + + /// Appends a piece of template codes to the stream. 
Same with `Append`, but a + /// new line will not be appended at the end. + void AppendNoNewLine(const std::string& text); + + /// Appends a new line to the stream. + void NewLine(); + + /// Deletes the last N charaters in the stream. If the stream has less than N + /// characters, deletes all. + void Backspace(int n); + + std::string ToString() const; + + /// Checks if the internal string stream is empty. Note: This method has + // overhead. + bool IsStreamEmpty() const; + + /// Clears all the internal string stream and value map. + void Clear(); + + private: + void AppendInternal(const std::string& text, bool newline); + + std::string indent_str_; + int indent_; + + std::map<std::string, std::string> value_map_; + std::string buffer_; + + ErrorReporter* err_; +}; + +/// Converts foo_bar_name to fooBarName. It's callers duty to make sure given +/// string "s" is already in snake case; or unexpected behavior may occur. +std::string SnakeCaseToCamelCase(const std::string& s); + +/// Joins 2 parts of file path into one, connected by unix path seperator '/'. +/// It's callers duty to ensure the two parts are valid. +std::string JoinPath(const std::string& a, const std::string& b); + +} // namespace codegen +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_TENSORFLOW_LITE_SUPPORT_CODEGEN_UTILS_H_ diff --git a/tensorflow_lite_support/codegen/utils_test.cc b/tensorflow_lite_support/codegen/utils_test.cc new file mode 100644 index 00000000..3111f37b --- /dev/null +++ b/tensorflow_lite_support/codegen/utils_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/codegen/utils.h" + +#include <gtest/gtest.h> + +namespace tflite { +namespace support { +namespace codegen { +namespace { + +TEST(ErrorReporterTest, TestReportError) { + ErrorReporter err; + err.Error("some text"); + EXPECT_EQ(err.GetMessage(), "[ERROR] some text\n"); + EXPECT_EQ(err.GetMessage(), ""); +} + +TEST(CodeGeneratorTest, TestExample) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetTokenValue("NAME", "Foo"); + const std::string text = R"(void {{NAME}}() { printf("%s", "{{NAME}}"); })"; + writer.Append(text); + writer.SetTokenValue("NAME", "Bar"); + writer.Append(text); + EXPECT_EQ( + "void Foo() { printf(\"%s\", \"Foo\"); }\n" + "void Bar() { printf(\"%s\", \"Bar\"); }\n", + writer.ToString()); +} + +TEST(CodeGeneratorTest, TestInexistentToken) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetTokenValue("NAME", "Foo"); + const std::string text = R"(void {{name}}() {})"; + writer.Append(text); + EXPECT_EQ(err.GetMessage(), + "[ERROR] Internal: Cannot find value with token 'name'\n"); +} + +TEST(CodeGeneratorTest, TestUnclosedToken) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetTokenValue("NAME", "Foo"); + const std::string text = R"(void {{NAME}() {})"; + writer.Append(text); + EXPECT_EQ(err.GetMessage(), + "[ERROR] Internal: Invalid template: {{token}} is not closed.\n"); +} + +TEST(CodeGeneratorTest, TestIndentControl) { + ErrorReporter err; + CodeWriter writer(&err); + writer.SetIndentString(" "); + 
writer.Indent(); + writer.AppendNoNewLine("abcde"); // Will indent + EXPECT_EQ(" abcde", writer.ToString()); + writer.Clear(); + writer.Indent(); + writer.AppendNoNewLine("abc\n\nde"); + // The blank line will not indent + EXPECT_EQ(" abc\n\n de", writer.ToString()); + writer.Clear(); + writer.Indent(); + writer.Append("abc"); + writer.Outdent(); + writer.AppendNoNewLine("def"); + EXPECT_EQ(" abc\ndef", writer.ToString()); +} + +TEST(CaseConversionTest, TestSnakeToCamel) { + EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel")); + EXPECT_EQ("imACamel", SnakeCaseToCamelCase("im_a_camel_")); + EXPECT_EQ("ImACamel", SnakeCaseToCamelCase("_im_a_camel")); + EXPECT_EQ("", SnakeCaseToCamelCase("_")); + EXPECT_EQ("camel", SnakeCaseToCamelCase("camel")); +} + +} // namespace +} // namespace codegen +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/BUILD b/tensorflow_lite_support/custom_ops/BUILD new file mode 100644 index 00000000..35734cf2 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/BUILD @@ -0,0 +1,43 @@ +load("@org_tensorflow//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +# This will generate the tf_text_flex_delegate cc_library, which is a custom +# flex delegate that only contains ops in listed models. +tflite_flex_cc_library( + name = "tf_text_flex_delegate", + additional_deps = ["@org_tensorflow_text//tensorflow_text:ops_lib"], + models = [ + # TODO(b/160817619) Replace with a more complex model. 
+ "testdata/sentencepiece_tokenizer_flex_op.tflite", + ], +) + +# bazel test --config=monolithic tensorflow_lite_support/custom_ops:tflite_inference_test +cc_test( + name = "tflite_inference_test", + srcs = ["tflite_inference_main.cc"], + args = ["--model=tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite"], + data = ["//tensorflow_lite_support/custom_ops:testdata/sentencepiece_tokenizer_flex_op.tflite"], + deps = [ + ":tf_text_flex_delegate", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/tools:command_line_flags", + ] + select({ + "@org_tensorflow//tensorflow:android": [ + "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite", + ], + "@org_tensorflow//tensorflow:ios": [ + "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite", + ], + "//conditions:default": [ + "@org_tensorflow//tensorflow/core:lib", + ], + }), +) diff --git a/tensorflow_lite_support/custom_ops/kernel/BUILD b/tensorflow_lite_support/custom_ops/kernel/BUILD new file mode 100644 index 00000000..b9b11de9 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/BUILD @@ -0,0 +1,146 @@ +# Placeholder for internal Python strict test compatibility macro. 
+load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "whitespace_tokenizer", + srcs = ["whitespace_tokenizer.cc"], + hdrs = ["whitespace_tokenizer.h"], + deps = [ + "@org_tensorflow//tensorflow/lite:context", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@utf_archive//:utf", + ], +) + +cc_library( + name = "whitespace_tokenizer_op_resolver", + srcs = ["whitespace_tokenizer_op_resolver.cc"], + hdrs = ["whitespace_tokenizer_op_resolver.h"], + visibility = ["//visibility:public"], + deps = [ + ":whitespace_tokenizer", + "@org_tensorflow//tensorflow/lite:framework", + ], +) + +pybind_extension( + name = "_pywrap_whitespace_tokenizer_op_resolver", + srcs = ["whitespace_tokenizer_op_resolver_wrapper.cc"], + hdrs = ["whitespace_tokenizer_op_resolver.h"], + additional_exported_symbols = ["AddWhitespaceTokenizerCustomOp"], + module_name = "_pywrap_whitespace_tokenizer_op_resolver", + visibility = ["//visibility:public"], + deps = [ + ":whitespace_tokenizer_op_resolver", + "@local_config_python//:python_headers", + "@org_tensorflow//tensorflow/lite:framework", + "@pybind11", + ], +) + +cc_test( + name = "whitespace_tokenizer_test", + srcs = ["whitespace_tokenizer_test.cc"], + deps = [ + ":whitespace_tokenizer", + "@com_google_googletest//:gtest_main", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +py_test( + name = "whitespace_tokenizer_py_test", + srcs = ["whitespace_tokenizer_test.py"], + data = [ + "testdata/whitespace_tokenizer_flex_delegate.tflite", + "testdata/whitespace_tokenizer_to_ragged_1d_input.tflite", + "testdata/whitespace_tokenizer_to_ragged_2d_input.tflite", + "testdata/whitespace_tokenizer_to_tensor.tflite", + ], + 
main = "whitespace_tokenizer_test.py", + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_pywrap_whitespace_tokenizer_op_resolver", + # numpy dep, + # tensorflow dep, + # tensorflow_text dep, + "@absl_py//absl/logging", + "@absl_py//absl/testing:parameterized", + ], +) + +cc_library( + name = "ngrams", + srcs = ["ngrams.cc"], + hdrs = ["ngrams.h"], + deps = [ + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:context", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], +) + +cc_library( + name = "ngrams_op_resolver", + srcs = ["ngrams_op_resolver.cc"], + hdrs = ["ngrams_op_resolver.h"], + visibility = ["//visibility:public"], + deps = [ + ":ngrams", + "@org_tensorflow//tensorflow/lite:framework", + ], +) + +pybind_extension( + name = "_pywrap_ngrams_op_resolver", + srcs = ["ngrams_op_resolver_wrapper.cc"], + hdrs = ["ngrams_op_resolver.h"], + additional_exported_symbols = ["AddNgramsCustomOp"], + module_name = "_pywrap_ngrams_op_resolver", + visibility = ["//visibility:public"], + deps = [ + ":ngrams_op_resolver", + "@local_config_python//:python_headers", + "@org_tensorflow//tensorflow/lite:framework", + "@pybind11", + ], +) + +cc_test( + name = "ngrams_test", + srcs = ["ngrams_test.cc"], + deps = [ + ":ngrams", + "@com_google_googletest//:gtest_main", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +py_test( + name = "ngrams_py_test", + srcs = ["ngrams_test.py"], + main = "ngrams_test.py", + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":_pywrap_ngrams_op_resolver", + # tensorflow dep, + # tensorflow_text dep, + "//tensorflow_lite_support/custom_ops/python:tflite_text_api", + "@absl_py//absl/logging", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams.cc 
b/tensorflow_lite_support/custom_ops/kernel/ngrams.cc new file mode 100644 index 00000000..3831c63c --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ngrams.cc @@ -0,0 +1,208 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h" + +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace ngrams { + +// This TFLite op implements the text.ngrams op when reduction_type = STRING_JOIN. +// +// Input: +// * data: A string tensor, or a ragged string tensor (a 1D string value tensor +// and one or more 1D int64 row_split tensors). +// +// Attributes: +// * width: scalar integer +// The width of the ngram window. +// * axis: scalar integer +// The axis to create ngrams along. For STRING_JOIN, this must be -1. +// * reduction_type: scalar string
// A string corresponding to the name of an enum value of text.Reduction
// Currently, only STRING_JOIN is supported.
// * string_separator: scalar string
// The separator string used to join tokens together.
//
// Output:
// * output: A string tensor that matches the rank of 'data'. Will be a ragged
// tensor if 'data' is a ragged tensor.
+ +// Both the input and output tensors use the same indices. +constexpr int kValues = 0; +constexpr int kRowSplitsStart = 1; + +// Reduction types. +constexpr char kStringJoin[] = "STRING_JOIN"; + +struct NgramsAttributes { + int width; + int axis; + std::string reduction_type; + std::string string_separator; + + explicit NgramsAttributes(const flexbuffers::Map& m) + : width(m["width"].AsInt32()), + axis(m["axis"].AsInt32()), + reduction_type(m["reduction_type"].ToString()), + string_separator(m["string_separator"].ToString()) {} +}; + +inline bool OutputIsTensor(TfLiteNode* node) { return NumOutputs(node) == 1; } +inline int NumRowSplits(TfLiteNode* node) { + return NumInputs(node) - kRowSplitsStart; +} + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer); + return new NgramsAttributes(flexbuffers::GetRoot(buffer_t, length).AsMap()); +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<NgramsAttributes*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto& attributes = + *reinterpret_cast<NgramsAttributes*>(node->user_data); + + TF_LITE_ENSURE(context, attributes.reduction_type == kStringJoin); + TF_LITE_ENSURE(context, attributes.axis == -1); + + TfLiteTensor* output_values = GetOutput(context, node, kValues); + if (OutputIsTensor(node)) { + const TfLiteTensor* input_values = GetInput(context, node, kValues); + int values_num_dims = NumDimensions(input_values); + TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(values_num_dims); + for (int i = 0; i < values_num_dims; ++i) { + output_values_shape->data[i] = SizeOfDimension(input_values, i); + } + output_values_shape->data[values_num_dims - 1] = + std::max(0, SizeOfDimension(input_values, values_num_dims - 1) - + attributes.width + 1); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_values, + output_values_shape)); + return 
kTfLiteOk; + } + + SetTensorToDynamic(output_values); + // The row_splits tensors maintain their shape, because only the + // innermost dimension will change. + for (int i = kRowSplitsStart; i < NumOutputs(node); ++i) { + const TfLiteTensor* input_row_splits = GetInput(context, node, i); + TfLiteTensor* output_row_splits = GetOutput(context, node, i); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_row_splits), 1); + TfLiteIntArray* output_row_splits_shape = TfLiteIntArrayCreate(1); + output_row_splits_shape->data[0] = SizeOfDimension(input_row_splits, 0); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_row_splits, + output_row_splits_shape)); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto& attributes = + *reinterpret_cast<NgramsAttributes*>(node->user_data); + + // Storage for the dummy input and output row_splits used in the tensor case. + std::vector<int64_t> tensor_input_row_splits; + std::vector<int64_t> tensor_output_row_splits; + + const int64_t* input_row_splits; + int64_t* output_row_splits; + int n_row_splits = 0; + + const TfLiteTensor* input_values = GetInput(context, node, kValues); + + if (OutputIsTensor(node)) { + // Generate mock input and output innermost row_splits. 
+ int64_t total_tokens = NumElements(input_values); + int64_t tokens_per_element = + SizeOfDimension(input_values, NumDimensions(input_values) - 1); + tensor_input_row_splits.reserve(total_tokens / tokens_per_element + 1); + tensor_output_row_splits.resize(total_tokens / tokens_per_element + 1); + for (int64_t i = 0; i <= total_tokens; i += tokens_per_element) { + tensor_input_row_splits.push_back(i); + } + input_row_splits = tensor_input_row_splits.data(); + output_row_splits = tensor_output_row_splits.data(); + n_row_splits = tensor_input_row_splits.size(); + } else { + int index = 0; + while (index < NumRowSplits(node) - 1) { + const TfLiteTensor* input_tensor_row_splits = + GetInput(context, node, kRowSplitsStart + index); + TfLiteTensor* output_tensor_row_splits = + GetOutput(context, node, kRowSplitsStart + index); + memcpy(output_tensor_row_splits->data.raw, + input_tensor_row_splits->data.raw, input_tensor_row_splits->bytes); + ++index; + } + + const TfLiteTensor* input_tensor_row_splits = + GetInput(context, node, kRowSplitsStart + index); + TfLiteTensor* output_tensor_row_splits = + GetOutput(context, node, kRowSplitsStart + index); + input_row_splits = input_tensor_row_splits->data.i64; + output_row_splits = output_tensor_row_splits->data.i64; + n_row_splits = SizeOfDimension(input_tensor_row_splits, 0); + } + + DynamicBuffer buffer; + StringRef separator; + separator.str = attributes.string_separator.c_str(); + separator.len = attributes.string_separator.length(); + int buffer_index = 0; + for (int i = 0; i < n_row_splits - 1; ++i) { + output_row_splits[i] = buffer_index; + std::vector<StringRef> tokens; + for (int j = input_row_splits[i]; j < input_row_splits[i + 1]; ++j) { + tokens.emplace_back(GetString(input_values, j)); + if (tokens.size() < attributes.width) continue; + tokens.erase(tokens.begin(), + tokens.begin() + tokens.size() - attributes.width); + buffer.AddJoinedString(tokens, separator); + ++buffer_index; + } + } + 
output_row_splits[n_row_splits - 1] = buffer_index; + + TfLiteTensor* output_values = GetOutput(context, node, kValues); + if (OutputIsTensor(node)) { + buffer.WriteToTensor(output_values, /*new_shape=*/nullptr); + } else { + buffer.WriteToTensorAsVector(output_values); + } + + return kTfLiteOk; +} + +} // namespace ngrams + +TfLiteRegistration* Register_tftext_Ngrams() { + static TfLiteRegistration r = {ngrams::Init, ngrams::Free, ngrams::Prepare, + ngrams::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams.h b/tensorflow_lite_support/custom_ops/kernel/ngrams.h new file mode 100644 index 00000000..56229065 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ngrams.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_H_ + +#include "tensorflow/lite/context.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_tftext_Ngrams(); + +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc new file mode 100644 index 00000000..b87fcac3 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.cc @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h" + +#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h" +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace tflite { +namespace ops { +namespace custom { + +void AddNgramsCustomOp(MutableOpResolver* resolver) { + resolver->AddCustom("tftext:Ngrams", Register_tftext_Ngrams()); +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h new file mode 100644 index 00000000..fc932688 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_OP_RESOLVER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_OP_RESOLVER_H_ + +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace tflite { +namespace ops { +namespace custom { + +// Adds the Ngrams custom op to an op resolver. +// This function can be loaded using dlopen. Since C++ function names get +// mangled, declare this function as extern "C", so its name is unchanged.
+extern "C" void AddNgramsCustomOp(MutableOpResolver* resolver); + +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_NGRAMS_OP_RESOLVER_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver_wrapper.cc b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver_wrapper.cc new file mode 100644 index 00000000..82747309 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver_wrapper.cc @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow_lite_support/custom_ops/kernel/ngrams_op_resolver.h" + +PYBIND11_MODULE(_pywrap_ngrams_op_resolver, m) { + m.doc() = "_pywrap_ngrams_op_resolver"; + m.def( + "AddNgramsCustomOp", + [](uintptr_t resolver) { + tflite::ops::custom::AddNgramsCustomOp( + reinterpret_cast<tflite::MutableOpResolver*>(resolver)); + }, + "Op registerer function for the tftext:Ngrams custom op."); +} diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc b/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc new file mode 100644 index 00000000..91ef47af --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_test.cc @@ -0,0 +1,293 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/ngrams.h" + +#include <string> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace ngrams { +namespace test { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +} // namespace + +class NgramsModel : public SingleOpModel { + public: + // Constructor for testing the op with a tf.Tensor + NgramsModel(int width, const std::string& string_separator, + const std::vector<std::string>& input_values, + const std::vector<int>& input_shape) { + input_values_ = AddInput(TensorType_STRING); + output_values_ = AddOutput(TensorType_STRING); + + BuildCustomOp(width, string_separator); + + BuildInterpreter({input_shape}); + PopulateStringTensor(input_values_, input_values); + Invoke(); + } + + // Constructor for the op with a tf.RaggedTensor + // Note: This interface uses row_lengths, as they're closer to the + // dimensions in a TensorShape, but internally everything is row_splits. 
+ NgramsModel(int width, const std::string& string_separator, + const std::vector<std::string>& input_values, + const std::vector<std::vector<int64_t>> nested_row_lengths) { + std::vector<std::vector<int>> input_shapes; + input_shapes.reserve(nested_row_lengths.size() + 1); + + input_values_ = AddInput(TensorType_STRING); + input_shapes.push_back({static_cast<int>(input_values.size())}); + output_values_ = AddOutput(TensorType_STRING); + + input_row_splits_.reserve(nested_row_lengths.size()); + output_row_splits_.reserve(nested_row_lengths.size()); + for (int i = 0; i < nested_row_lengths.size(); ++i) { + input_row_splits_.push_back(AddInput(TensorType_INT64)); + input_shapes.push_back( + {static_cast<int>(nested_row_lengths[i].size() + 1)}); + output_row_splits_.push_back(AddOutput(TensorType_INT64)); + } + + BuildCustomOp(width, string_separator); + + BuildInterpreter(input_shapes); + PopulateStringTensor(input_values_, input_values); + for (int i = 0; i < nested_row_lengths.size(); ++i) { + std::vector<int64_t> row_splits; + row_splits.reserve(nested_row_lengths[i].size() + 1); + int64_t index = 0; + row_splits.push_back(index); + for (int64_t row_length : nested_row_lengths[i]) { + index += row_length; + row_splits.push_back(index); + } + PopulateTensor(input_row_splits_[i], row_splits); + } + Invoke(); + } + + std::vector<int> GetValuesTensorShape() { + return GetTensorShape(output_values_); + } + + std::vector<std::string> ExtractValuesTensorVector() { + std::vector<std::string> r; + TfLiteTensor* tensor = interpreter_->tensor(output_values_); + int n = GetStringCount(tensor); + for (int i = 0; i < n; ++i) { + StringRef ref = GetString(tensor, i); + r.emplace_back(ref.str, ref.len); + } + return r; + } + + int GetNumNestedRowLengths() { return output_row_splits_.size(); } + + std::vector<int> GetRowLengthsTensorShape(int i) { + std::vector<int> shape = GetTensorShape(output_row_splits_[i]); + --shape[0]; + return shape; + } + + std::vector<int64_t> 
ExtractRowLengthsTensorVector(int i) { + std::vector<int64_t> row_splits = + ExtractVector<int64_t>(output_row_splits_[i]); + std::vector<int64_t> row_lengths; + row_lengths.reserve(row_splits.size() - 1); + int64_t head = row_splits[0]; + for (int i = 1; i < row_splits.size(); ++i) { + int64_t tail = row_splits[i]; + row_lengths.push_back(tail - head); + head = tail; + } + return row_lengths; + } + + private: + void BuildCustomOp(int width, const std::string& string_separator) { + flexbuffers::Builder fbb; + size_t start_map = fbb.StartMap(); + fbb.Int("width", width); + fbb.String("string_separator", string_separator); + fbb.Int("axis", -1); + fbb.String("reduction_type", "STRING_JOIN"); + fbb.EndMap(start_map); + fbb.Finish(); + + SetCustomOp("tftext:Ngrams", fbb.GetBuffer(), Register_tftext_Ngrams); + } + + int input_values_; + std::vector<int> input_row_splits_; + int output_values_; + std::vector<int> output_row_splits_; +}; + +TEST(NgramsTest, TensorSingleSequenceWidthTwo) { + NgramsModel m(2, " ", {"this", "is", "a", "test"}, std::vector<int>{4}); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("this is", "is a", "a test")); +} + +TEST(NgramsTest, TensorSingleSequenceWidthThree) { + NgramsModel m(3, " ", {"this", "is", "a", "test"}, std::vector<int>{4}); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(2)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("this is a", "is a test")); +} + +TEST(NgramsTest, TensorSingleSequenceLongerSeparator) { + NgramsModel m(2, "...", {"this", "is", "a", "test"}, std::vector<int>{4}); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("this...is", "is...a", "a...test")); +} + +TEST(NgramsTest, TensorSingleSequenceWidthTooLong) { + NgramsModel m(5, " ", {"this", "is", "a", "test"}, std::vector<int>{4}); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(0)); + 
EXPECT_THAT(m.ExtractValuesTensorVector(), ElementsAre()); +} + +TEST(NgramsTest, TensorMultidimensionalInputWidthTwo) { + NgramsModel m(2, " ", + { + "0,0,0", "0,0,1", "0,0,2", "0,0,3", // + "0,1,0", "0,1,1", "0,1,2", "0,1,3", // + "0,2,0", "0,2,1", "0,2,2", "0,2,3", // + "1,0,0", "1,0,1", "1,0,2", "1,0,3", // + "1,1,0", "1,1,1", "1,1,2", "1,1,3", // + "1,2,0", "1,2,1", "1,2,2", "1,2,3", // + }, + std::vector<int>{2, 3, 4}); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(2, 3, 3)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAreArray({ + "0,0,0 0,0,1", "0,0,1 0,0,2", "0,0,2 0,0,3", // + "0,1,0 0,1,1", "0,1,1 0,1,2", "0,1,2 0,1,3", // + "0,2,0 0,2,1", "0,2,1 0,2,2", "0,2,2 0,2,3", // + "1,0,0 1,0,1", "1,0,1 1,0,2", "1,0,2 1,0,3", // + "1,1,0 1,1,1", "1,1,1 1,1,2", "1,1,2 1,1,3", // + "1,2,0 1,2,1", "1,2,1 1,2,2", "1,2,2 1,2,3", // + })); +} + +TEST(NgramsTest, RaggedTensorSingleSequenceWidthTwo) { + std::vector<std::vector<int64_t>> nested_row_lengths; + nested_row_lengths.push_back({4}); + NgramsModel m(2, " ", {"this", "is", "a", "test"}, + nested_row_lengths); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("this is", "is a", "a test")); + ASSERT_THAT(m.GetNumNestedRowLengths(), 1); + EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1)); + EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(3)); +} + +TEST(NgramsTest, RaggedTensorSingleSequenceWidthThree) { + std::vector<std::vector<int64_t>> nested_row_lengths; + nested_row_lengths.push_back({4}); + NgramsModel m(3, " ", {"this", "is", "a", "test"}, nested_row_lengths); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(2)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("this is a", "is a test")); + ASSERT_THAT(m.GetNumNestedRowLengths(), 1); + EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1)); + EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(2)); +} + +TEST(NgramsTest, 
RaggedTensorSingleSequenceLongerSeparator) { + std::vector<std::vector<int64_t>> nested_row_lengths; + nested_row_lengths.push_back({4}); + NgramsModel m(2, "<>", {"this", "is", "a", "test"}, nested_row_lengths); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("this<>is", "is<>a", "a<>test")); + ASSERT_THAT(m.GetNumNestedRowLengths(), 1); + EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1)); + EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(3)); +} + +TEST(NgramsTest, RaggedTensorSingleSequenceWidthTooLong) { + std::vector<std::vector<int64_t>> nested_row_lengths; + nested_row_lengths.push_back({4}); + NgramsModel m(5, " ", {"this", "is", "a", "test"}, nested_row_lengths); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(0)); + EXPECT_THAT(m.ExtractValuesTensorVector(), ElementsAre()); + ASSERT_THAT(m.GetNumNestedRowLengths(), 1); + EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(1)); + EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(0)); +} + +TEST(NgramsTest, RaggedTensorMultidimensionalInputWidthTwo) { + std::vector<std::vector<int64_t>> nested_row_lengths; + nested_row_lengths.push_back({4, 2, 1}); + nested_row_lengths.push_back({5, 4, 3, 2, 2, 3, 4, 6}); + NgramsModel m(2, " ", + { + "0,0,0", "0,0,1", "0,0,2", "0,0,3", "0,0,4", // + "0,1,0", "0,1,1", "0,1,2", "0,1,3", // + "0,2,0", "0,2,1", "0,2,2", // + "0,3,0", "0,3,1", // + "1,0,0", "1,0,1", // + "1,1,0", "1,1,1", "1,1,2", // + "1,2,0", "1,2,1", "1,2,2", "1,2,3", // + "2,0,0", "2,0,1", "2,0,2", "2,0,3", "2,0,4", "2,0,5", // + }, + nested_row_lengths); + + std::vector<std::string> expected_values = { + "0,0,0 0,0,1", "0,0,1 0,0,2", "0,0,2 0,0,3", "0,0,3 0,0,4", // + "0,1,0 0,1,1", "0,1,1 0,1,2", "0,1,2 0,1,3", // + "0,2,0 0,2,1", "0,2,1 0,2,2", // + "0,3,0 0,3,1", // + "1,0,0 1,0,1", // + "1,1,0 1,1,1", "1,1,1 1,1,2", // + "1,2,0 1,2,1", "1,2,1 1,2,2", "1,2,2 1,2,3", // + "2,0,0 2,0,1", "2,0,1 
2,0,2", "2,0,2 2,0,3", "2,0,3 2,0,4", + "2,0,4 2,0,5", // + }; + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(expected_values.size())); + EXPECT_THAT(m.ExtractValuesTensorVector(), ElementsAreArray(expected_values)); + ASSERT_THAT(m.GetNumNestedRowLengths(), 2); + EXPECT_THAT(m.GetRowLengthsTensorShape(0), ElementsAre(3)); + EXPECT_THAT(m.ExtractRowLengthsTensorVector(0), ElementsAre(4, 2, 1)); + EXPECT_THAT(m.GetRowLengthsTensorShape(1), ElementsAre(8)); + EXPECT_THAT(m.ExtractRowLengthsTensorVector(1), + ElementsAre(4, 3, 2, 1, 1, 2, 3, 5)); +} + +} // namespace test +} // namespace ngrams +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py b/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py new file mode 100644 index 00000000..e52ca285 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ngrams_test.py @@ -0,0 +1,266 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Lint as: python3 +"""Tests for tensorflow_lite_support.custom_ops.ngrams.""" + +import os +import sys +import timeit + +from absl import logging +from absl.testing import parameterized +import tensorflow as tf +import tensorflow_text as tf_text +from tensorflow.lite.python import interpreter as interpreter_wrapper # pylint: disable=g-direct-tensorflow-import +from tensorflow_lite_support.custom_ops.python import tflite_text_api + +# Force loaded shared object symbols to be globally visible. This is needed so +# that the interpreter_wrapper, in one .so file, can see the op resolver +# in a different .so file. Note that this may already be set by default. +# pylint: disable=g-import-not-at-top,g-bad-import-order,unused-import +if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'): + sys.setdlopenflags(sys.getdlopenflags() | os.RTLD_GLOBAL) +from tensorflow_lite_support.custom_ops.kernel import _pywrap_ngrams_op_resolver + +TEST_CASES = [ + [['this', 'is', 'a', 'test']], + [['one']], + [['two', 'tokens'], ['a', 'b']], + [['has', 'three', 'tokens'], ['a', 'b', 'c'], ['0', '1', '2']], + [['a', 'ragged', 'tensor'], ['a'], ['0', '1']], + [[['a', 'multidimensional', 'test', 'case'], ['a', 'b', 'c', 'd', 'e']], + [['0', '1', '2', '3', '4', '5']]], +] + +INVOKES_FOR_SINGLE_OP_BENCHMARK = 1000 +INVOKES_FOR_FLEX_DELEGATE_BENCHMARK = 100 + + +class NgramsTest(parameterized.TestCase): + + _models = {} + + def _make_model(self, rank, width, ragged_tensor=False, flex=False): + temp_dir = self.create_tempdir().full_path + + key = (rank, width, ragged_tensor, flex) + if key in self._models: + return self._models[key] + + ngrams = tf_text.ngrams if flex else tflite_text_api.ngrams + + if ragged_tensor: + input_signature = [tf.TensorSpec(shape=[None], dtype=tf.string)] + rs = rank - 1 + input_signature += [tf.TensorSpec(shape=[None], dtype=tf.int64)] * rs + + class Model(tf.Module): + + 
@tf.function(input_signature=input_signature) + def __call__(self, values, *args): + row_splits = list(args) + row_splits.reverse() + input_tensor = tf.RaggedTensor.from_nested_row_splits( + flat_values=values, nested_row_splits=tuple(row_splits)) + output_tensor = ngrams( + input_tensor, width, reduction_type=tf_text.Reduction.STRING_JOIN) + output = [output_tensor.flat_values] + output.extend(list(output_tensor.nested_row_splits)) + output.reverse() + return tuple(output) + + tf.saved_model.save(Model(), temp_dir) + else: + shape = [None] * rank + + class Model(tf.Module): + + @tf.function( + input_signature=[tf.TensorSpec(shape=shape, dtype=tf.string)]) + def __call__(self, input_tensor): + return ngrams( + input_tensor, width, reduction_type=tf_text.Reduction.STRING_JOIN) + + tf.saved_model.save(Model(), temp_dir) + + converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir) + converter.inference_type = tf.float32 + converter.inference_input_type = tf.float32 + converter.allow_custom_ops = not flex + if flex: + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS + ] + model = converter.convert() + self._models[key] = model + return model + + @parameterized.parameters([t] for t in TEST_CASES) + def test_width_2_tensor_equivalence(self, test_case): + input_tensor = tf.ragged.constant(test_case).to_tensor() + tf_output = tf_text.ngrams( + input_tensor, 2, reduction_type=tf_text.Reduction.STRING_JOIN) + + rank = input_tensor.shape.rank + model = self._make_model(rank, 2, ragged_tensor=False, flex=False) + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=model, custom_op_registerers=['AddNgramsCustomOp']) + interpreter.resize_tensor_input(0, input_tensor.shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + input_tensor.numpy()) + interpreter.invoke() + tflite_output = interpreter.get_tensor( + 
interpreter.get_output_details()[0]['index']) + + self.assertEqual(tf_output.numpy().tolist(), tflite_output.tolist()) + + @parameterized.parameters([t] for t in TEST_CASES) + def test_width_3_tensor_equivalence(self, test_case): + input_tensor = tf.ragged.constant(test_case).to_tensor() + tf_output = tf_text.ngrams( + input_tensor, 3, reduction_type=tf_text.Reduction.STRING_JOIN) + + rank = input_tensor.shape.rank + model = self._make_model(rank, 3, ragged_tensor=False, flex=False) + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=model, custom_op_registerers=['AddNgramsCustomOp']) + interpreter.resize_tensor_input(0, input_tensor.shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + input_tensor.numpy()) + interpreter.invoke() + tflite_output = interpreter.get_tensor( + interpreter.get_output_details()[0]['index']) + self.assertEqual(tf_output.numpy().tolist(), tflite_output.tolist()) + + @parameterized.parameters([t] for t in TEST_CASES) + def test_width_2_ragged_tensor_equivalence(self, test_case): + input_tensor = tf.ragged.constant(test_case) + tf_output = tf_text.ngrams( + input_tensor, 2, reduction_type=tf_text.Reduction.STRING_JOIN) + + rank = input_tensor.shape.rank + model = self._make_model(rank, 2, ragged_tensor=True, flex=False) + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=model, custom_op_registerers=['AddNgramsCustomOp']) + interpreter.resize_tensor_input(0, input_tensor.flat_values.shape) + for r in range(rank - 1): + interpreter.resize_tensor_input(r + 1, + input_tensor.nested_row_splits[r].shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + input_tensor.flat_values.numpy()) + for r in range(rank - 1): + interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'], + input_tensor.nested_row_splits[r].numpy()) + interpreter.invoke() + tflite_output_values = 
interpreter.get_tensor( + interpreter.get_output_details()[0]['index']) + self.assertEqual(tf_output.flat_values.numpy().tolist(), + tflite_output_values.tolist()) + for i in range(rank - 1): + tflite_output_cur_row_splits = interpreter.get_tensor( + interpreter.get_output_details()[i + 1]['index']) + self.assertEqual(tf_output.nested_row_splits[i].numpy().tolist(), + tflite_output_cur_row_splits.tolist()) + + @parameterized.parameters([t] for t in TEST_CASES) + def test_width_3_ragged_tensor_equivalence(self, test_case): + input_tensor = tf.ragged.constant(test_case) + tf_output = tf_text.ngrams( + input_tensor, 3, reduction_type=tf_text.Reduction.STRING_JOIN) + + rank = input_tensor.shape.rank + model = self._make_model(rank, 3, ragged_tensor=True, flex=False) + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=model, custom_op_registerers=['AddNgramsCustomOp']) + interpreter.resize_tensor_input(0, input_tensor.flat_values.shape) + for r in range(rank - 1): + interpreter.resize_tensor_input(r + 1, + input_tensor.nested_row_splits[r].shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + input_tensor.flat_values.numpy()) + for r in range(rank - 1): + interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'], + input_tensor.nested_row_splits[r].numpy()) + interpreter.invoke() + tflite_output_values = interpreter.get_tensor( + interpreter.get_output_details()[0]['index']) + self.assertEqual(tf_output.flat_values.numpy().tolist(), + tflite_output_values.tolist()) + for i in range(rank - 1): + tflite_output_cur_row_splits = interpreter.get_tensor( + interpreter.get_output_details()[i + 1]['index']) + self.assertEqual(tf_output.nested_row_splits[i].numpy().tolist(), + tflite_output_cur_row_splits.tolist()) + + def test_latency(self): + latency_op = 0.0 + for test_case in TEST_CASES: + input_tensor = tf.ragged.constant(test_case) + + rank = input_tensor.shape.rank + model = 
self._make_model(rank, 3, ragged_tensor=True, flex=False) + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=model, custom_op_registerers=['AddNgramsCustomOp']) + interpreter.resize_tensor_input(0, input_tensor.flat_values.shape) + for r in range(rank - 1): + interpreter.resize_tensor_input(r + 1, + input_tensor.nested_row_splits[r].shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + input_tensor.flat_values.numpy()) + for r in range(rank - 1): + interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'], + input_tensor.nested_row_splits[r].numpy()) + start_time = timeit.default_timer() + for _ in range(INVOKES_FOR_SINGLE_OP_BENCHMARK): + interpreter.invoke() + latency_op = latency_op + timeit.default_timer() - start_time + latency_op = latency_op / ( + INVOKES_FOR_SINGLE_OP_BENCHMARK * len(TEST_CASES)) + + latency_flex = 0.0 + for test_case in TEST_CASES: + input_tensor = tf.ragged.constant(test_case) + + rank = input_tensor.shape.rank + model = self._make_model(rank, 3, ragged_tensor=True, flex=True) + interpreter = interpreter_wrapper.Interpreter(model_content=model) + interpreter.resize_tensor_input(0, input_tensor.flat_values.shape) + for r in range(rank - 1): + interpreter.resize_tensor_input(r + 1, + input_tensor.nested_row_splits[r].shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + input_tensor.flat_values.numpy()) + for r in range(rank - 1): + interpreter.set_tensor(interpreter.get_input_details()[r + 1]['index'], + input_tensor.nested_row_splits[r].numpy()) + start_time = timeit.default_timer() + for _ in range(INVOKES_FOR_FLEX_DELEGATE_BENCHMARK): + interpreter.invoke() + latency_flex = latency_flex + timeit.default_timer() - start_time + latency_flex = latency_flex / ( + INVOKES_FOR_FLEX_DELEGATE_BENCHMARK * len(TEST_CASES)) + + logging.info('Latency (single op): %fms', latency_op * 1000.0) + 
logging.info('Latency (flex delegate): %fms', latency_flex * 1000.0) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD b/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD new file mode 100644 index 00000000..a512cdc8 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ragged/BUILD @@ -0,0 +1,81 @@ +# RaggedTensors suppport in TFLite + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "ragged_tensor_to_tensor_tflite", + srcs = ["ragged_tensor_to_tensor_tflite.cc"], + visibility = [ + "//visibility:public", + ], + deps = [ + "@flatbuffers", + "@org_tensorflow//tensorflow/core/util:ragged_to_dense_util_common", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + "@org_tensorflow//tensorflow/lite/kernels/internal:types", + ], +) + +cc_test( + name = "ragged_tensor_to_tensor_tflite_test", + srcs = ["ragged_tensor_to_tensor_tflite_test.cc"], + deps = [ + ":ragged_tensor_to_tensor_tflite", + "@com_google_googletest//:gtest_main", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) + +cc_library( + name = "py_tflite_registerer", + srcs = ["py_tflite_registerer.cc"], + hdrs = ["py_tflite_registerer.h"], + deps = [ + ":ragged_tensor_to_tensor_tflite", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +cc_library( + name = "ragged_range_tflite", + srcs = ["ragged_range_tflite.cc"], + visibility = [ + "//visibility:public", + ], + 
deps = [ + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + "@org_tensorflow//tensorflow/lite/kernels/internal:types", + ], +) + +cc_test( + name = "ragged_range_tflite_test", + srcs = ["ragged_range_tflite_test.cc"], + deps = [ + ":ragged_range_tflite", + "@com_google_googletest//:gtest_main", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], +) diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD b/tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD new file mode 100644 index 00000000..650ab90b --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ragged/py/BUILD @@ -0,0 +1,27 @@ +# Python wrapper used for test. 
+ +load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = [ + "//tensorflow_lite_support:users", + ], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "pywrap_tflite_registerer", + srcs = [ + "pywrap_tflite_registerer.cc", + ], + additional_exported_symbols = ["TFLite_RaggedTensorToTensorRegisterer"], + module_name = "pywrap_tflite_registerer", + srcs_version = "PY3ONLY", + deps = [ + "//tensorflow_lite_support/custom_ops/kernel/ragged:py_tflite_registerer", + "@local_config_python//:python_headers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@pybind11", + ], +) diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/py/pywrap_tflite_registerer.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/py/pywrap_tflite_registerer.cc new file mode 100644 index 00000000..0b9432a9 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ragged/py/pywrap_tflite_registerer.cc @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h" + +PYBIND11_MODULE(pywrap_tflite_registerer, m) { + m.doc() = R"pbdoc( + pywrap_tflite_registerer + A module with a wrapper that adds to a Python wrapper for TFLite + ragged_tensor_to_tensor. + )pbdoc"; + m.def( + "TFLite_RaggedTensorToTensorRegisterer", + [](uintptr_t resolver) { + TFLite_RaggedTensorToTensorRegisterer( + reinterpret_cast<tflite::MutableOpResolver*>(resolver)); + }, + R"pbdoc( + The function that adds RaggedTensorToTensor to the TFLite interpreter. + )pbdoc"); +} diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.cc new file mode 100644 index 00000000..7c93d8b1 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.cc @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h" + +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace tflite { +namespace ops { +namespace custom { +TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR(); +} // namespace custom +} // namespace ops +} // namespace tflite + +extern "C" void TFLite_RaggedTensorToTensorRegisterer( + tflite::MutableOpResolver* resolver) { + resolver->AddCustom("RaggedTensorToTensor", + tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR()); +} diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h b/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h new file mode 100644 index 00000000..ade3c5c1 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ragged/py_tflite_registerer.h @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_ + +#include "tensorflow/lite/mutable_op_resolver.h" + +// C-function that is called from the Python Wrapper. 
+ +extern "C" void TFLite_RaggedTensorToTensorRegisterer( + tflite::MutableOpResolver *resolver); + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_RAGGED_PY_TFLITE_REGISTERER_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc new file mode 100644 index 00000000..a35a6db9 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite.cc @@ -0,0 +1,192 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <algorithm> +#include <functional> + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace ragged { +namespace ragged_range { +namespace { +constexpr int kInputStarts = 0; +constexpr int kInputLimits = 1; +constexpr int kInputDeltas = 2; + +constexpr int kOutputNestedSplits = 0; +constexpr int kOutputDenseValues = 1; + +TfLiteIntArray* IntArrayFromInt(int x) { + TfLiteIntArray* result = TfLiteIntArrayCreate(1); + result->data[0] = x; + return result; +} + +// Returns the number of elements in the specified range. +template <typename T, typename SPLITS_TYPE> +SPLITS_TYPE RangeSize(T start, T limit, T delta) { + if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) { + return 0; + } + // The following is copied from tensorflow::RangeOp::Compute(). + return ( + std::is_integral<T>::value + ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) + : std::ceil(std::abs((limit - start) / delta))); +} + +template <typename T, typename SPLITS_TYPE> +TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor& input_starts = + context->tensors[node->inputs->data[kInputStarts]]; + TfLiteTensor& input_limits = + context->tensors[node->inputs->data[kInputLimits]]; + TfLiteTensor& input_deltas = + context->tensors[node->inputs->data[kInputDeltas]]; + // Determine which tensors we need to broadcast. 
+ const bool broadcast_starts = NumElements(&input_starts) == 1; + const bool broadcast_limits = NumElements(&input_limits) == 1; + const bool broadcast_deltas = NumElements(&input_deltas) == 1; + + // nrows (number of output rows) is the size of the non-broadcast inputs, + // or 1 if all inputs are scalars. + std::vector<int> in_sizes; + if (!broadcast_starts) in_sizes.push_back(input_starts.dims->data[0]); + if (!broadcast_limits) in_sizes.push_back(input_limits.dims->data[0]); + if (!broadcast_deltas) in_sizes.push_back(input_deltas.dims->data[0]); + if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes), + std::not_equal_to<>()) != std::end(in_sizes)) { + context->ReportError( + context, + "Invalid argument: starts, limits, and deltas must have the " + "same shape"); + return kTfLiteError; + } + + const SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes.front(); + + const T* starts = GetTensorData<T>(&input_starts); + const T* limits = GetTensorData<T>(&input_limits); + const T* deltas = GetTensorData<T>(&input_deltas); + + TfLiteTensor& rt_nested_splits_out = + context->tensors[node->outputs->data[kOutputNestedSplits]]; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, &rt_nested_splits_out, + IntArrayFromInt(nrows + 1))); + SPLITS_TYPE* rt_nested_splits = + GetTensorData<SPLITS_TYPE>(&rt_nested_splits_out); + rt_nested_splits[0] = 0; + + for (int row = 0; row < nrows; ++row) { + const T start = broadcast_starts ? starts[0] : starts[row]; + const T limit = broadcast_limits ? limits[0] : limits[row]; + const T delta = broadcast_deltas ? 
deltas[0] : deltas[row]; + if (delta == 0) { + context->ReportError(context, "Invalid argument: Requires delta != 0"); + return kTfLiteError; + } + rt_nested_splits[row + 1] = + rt_nested_splits[row] + RangeSize<T, SPLITS_TYPE>(start, limit, delta); + } + const SPLITS_TYPE nvals = rt_nested_splits[nrows]; + + TfLiteTensor& rt_dense_values_out = + context->tensors[node->outputs->data[kOutputDenseValues]]; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, &rt_dense_values_out, + IntArrayFromInt(nvals))); + T* rt_dense_values = GetTensorData<T>(&rt_dense_values_out); + int value_index = 0; + for (int row = 0; row < nrows; ++row) { + const SPLITS_TYPE row_size = + rt_nested_splits[row + 1] - rt_nested_splits[row]; + T value = broadcast_starts ? starts[0] : starts[row]; + const T delta = broadcast_deltas ? deltas[0] : deltas[row]; + for (SPLITS_TYPE i = 0; i < row_size; ++i) { + rt_dense_values[value_index++] = value; + value += delta; + } + } + return kTfLiteOk; +} + +template <typename SPLITS_TYPE> +TfLiteStatus EvalSplitsT(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor& rt_dense_values_out = + context->tensors[node->outputs->data[kOutputDenseValues]]; + switch (rt_dense_values_out.type) { + case kTfLiteInt32: + return EvalT<int32_t, SPLITS_TYPE>(context, node); + case kTfLiteInt64: + return EvalT<int64_t, SPLITS_TYPE>(context, node); + case kTfLiteFloat32: + return EvalT<float, SPLITS_TYPE>(context, node); + case kTfLiteFloat64: + return EvalT<double, SPLITS_TYPE>(context, node); + default: + context->ReportError(context, + "Invalid argument: Not supported VALUES type"); + return kTfLiteError; + } +} +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Set outputs dynamic. 
+ TfLiteTensor& nested_splits = + context->tensors[node->outputs->data[kOutputNestedSplits]]; + SetTensorToDynamic(&nested_splits); + TfLiteTensor& dense_values = + context->tensors[node->outputs->data[kOutputDenseValues]]; + SetTensorToDynamic(&dense_values); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor& rt_nested_splits_out = + context->tensors[node->outputs->data[kOutputNestedSplits]]; + switch (rt_nested_splits_out.type) { + case kTfLiteInt32: + return EvalSplitsT<int32_t>(context, node); + case kTfLiteInt64: + return EvalSplitsT<int64_t>(context, node); + default: + context->ReportError(context, + "Invalid argument: Not supported ROW_SPLITS type"); + return kTfLiteError; + } +} + +} // namespace ragged_range +} // namespace ragged +TfLiteRegistration* Register_RAGGED_RANGE() { + static TfLiteRegistration r = {nullptr /*Initialize*/, nullptr /*Free*/, + ragged::ragged_range::Prepare, + ragged::ragged_range::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc new file mode 100644 index 00000000..54cf4459 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_range_tflite_test.cc @@ -0,0 +1,155 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <initializer_list> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace ops { +namespace custom { +TfLiteRegistration* Register_RAGGED_RANGE(); +} // namespace custom +} // namespace ops + +namespace { + +template <typename T> +class RaggedRangeOpModel : public SingleOpModel { + public: + static TensorType GetType(); + + RaggedRangeOpModel(const std::vector<T>& start, const std::vector<T>& limits, + const std::vector<T>& deltas) { + const TensorType value_type = GetType(); + std::vector<std::vector<int>> shapes; + input_start_ = AddInput(value_type); + shapes.push_back({static_cast<int>(start.size())}); + input_limits_ = AddInput(value_type); + shapes.push_back({static_cast<int>(limits.size())}); + input_deltas_ = AddInput(value_type); + shapes.push_back({static_cast<int>(deltas.size())}); + + output_splits_ = AddOutput(TensorType_INT32); + output_values_ = AddOutput(value_type); + + SetCustomOp("RaggedRange", {}, ops::custom::Register_RAGGED_RANGE); + BuildInterpreter(shapes); + + PopulateTensor(input_start_, start); + PopulateTensor(input_limits_, limits); + PopulateTensor(input_deltas_, deltas); + } + + std::vector<int32> GetSplits() { + return ExtractVector<int32>(output_splits_); + } + std::vector<T> GetValues() const { return ExtractVector<T>(output_values_); } + + protected: + int input_start_ = -1; + int input_limits_ = -1; + int input_deltas_ = -1; + + int 
output_splits_ = -1; + int output_values_ = -1; +}; + +template <> +TensorType RaggedRangeOpModel<int32>::GetType() { + return TensorType_INT32; +} + +template <> +TensorType RaggedRangeOpModel<float>::GetType() { + return TensorType_FLOAT32; +} + +TEST(RaggedRangeOpTest, IntValues) { + RaggedRangeOpModel<int32> model({0, 5, 8, 5}, // Starts. + {8, 7, 8, 1}, // Limits. + {2, 1, 1, -1}); // Deltas. + model.Invoke(); + + EXPECT_THAT(model.GetSplits(), + testing::UnorderedElementsAreArray({0, 4, 6, 6, 10})); + EXPECT_THAT(model.GetValues(), testing::UnorderedElementsAreArray( + {0, 2, 4, 6, 5, 6, 5, 4, 3, 2})); +} + +TEST(RaggedRangeOpTest, FloatValues) { + RaggedRangeOpModel<float> model({0, 5, 8, 5}, // Starts. + {8, 7, 8, 1}, // Limits. + {2, 1, 1, -1}); // Deltas. + model.Invoke(); + + EXPECT_THAT(model.GetSplits(), + testing::UnorderedElementsAreArray({0, 4, 6, 6, 10})); + EXPECT_THAT(model.GetValues(), testing::UnorderedElementsAreArray( + {0, 2, 4, 6, 5, 6, 5, 4, 3, 2})); +} + +TEST(RaggedRangeOpTest, BroadcastDelta) { + RaggedRangeOpModel<int32> model({0, 5, 8}, // Starts. + {8, 7, 8}, // Limits. + {1}); // Deltas. + model.Invoke(); + + EXPECT_THAT(model.GetSplits(), + testing::UnorderedElementsAreArray({0, 8, 10, 10})); + EXPECT_THAT(model.GetValues(), testing::UnorderedElementsAreArray( + {0, 1, 2, 3, 4, 5, 6, 7, 5, 6})); +} + +TEST(RaggedRangeOpTest, BroadcastStartDeltas) { + RaggedRangeOpModel<int32> model({0}, // Starts. + {10}, // Limits. + {2, 1}); // Deltas. + model.Invoke(); + + EXPECT_THAT(model.GetSplits(), + testing::UnorderedElementsAreArray({0, 5, 15})); + EXPECT_THAT(model.GetValues(), + testing::UnorderedElementsAreArray( + {0, 2, 4, 6, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9})); +} + +TEST(RaggedRangeOpTest, BadDeltas) { + RaggedRangeOpModel<int32> model({0, 5, 8, 5}, // Starts. + {8, 7, 7, 9}, // Limits. + {0, 1, 1, 1}); // Deltas. 
+ EXPECT_EQ(model.InvokeUnchecked(), kTfLiteError); +} + +TEST(RaggedRangeOpTest, ZeroRange) { + RaggedRangeOpModel<int32> model({0, 7}, // Starts. + {8, 5}, // Limits. + {1, 1}); // Deltas. + model.Invoke(); + EXPECT_THAT(model.GetSplits(), testing::UnorderedElementsAreArray({0, 8, 8})); + EXPECT_THAT(model.GetValues(), + testing::UnorderedElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7})); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc new file mode 100644 index 00000000..09ac76c7 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite.cc @@ -0,0 +1,690 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <cstdint> +#include <memory> + +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/core/util/ragged_to_dense_util_common.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace ragged { +namespace ragged_tensor_to_tensor { +namespace { + +constexpr int kShapeInput = 0; +constexpr int kValuesInput = 1; +constexpr int kDefaultValueInput = 2; +constexpr int kFirstPartitionInputIndex = 3; + +constexpr int kOutputTensor = 0; + +constexpr char kRowPartitionTypesAttr[] = "row_partition_types"; + +struct ConversionAttributes { + std::vector<tensorflow::RowPartitionType> partition_types; + int ragged_rank = 0; + + tensorflow::RowPartitionType GetRowPartitionTypeByDimension( + int dimension) const { + if (partition_types.front() == + tensorflow::RowPartitionType::FIRST_DIM_SIZE) { + return partition_types[dimension + 1]; + } else { + return partition_types[dimension]; + } + } +}; +template <typename INDEX_TYPE> +int GetFirstDimensionSizeT(TfLiteContext* context, + const TfLiteTensor& first_partition_input, + const ConversionAttributes* attributes) { + const tensorflow::RowPartitionType first_partition_type = + attributes->partition_types.front(); + switch (first_partition_type) { + case tensorflow::RowPartitionType::FIRST_DIM_SIZE: + return *GetTensorData<INDEX_TYPE>(&first_partition_input); + case tensorflow::RowPartitionType::VALUE_ROWIDS: + context->ReportError(context, + "Cannot handle VALUE_ROWIDS in first dimension."); + return -1; + case tensorflow::RowPartitionType::ROW_SPLITS: { + const auto shape = 
GetTensorShape(&first_partition_input); + return shape.Dims(0) - 1; + } + + default: + context->ReportError( + context, "Cannot handle type ", + RowPartitionTypeToString(first_partition_type).c_str()); + return -1; + } +} + +int GetFirstDimensionSize(TfLiteContext* context, + const TfLiteTensor& first_partition_input, + const ConversionAttributes* attributes) { + switch (first_partition_input.type) { + case kTfLiteInt32: + return GetFirstDimensionSizeT<int32_t>(context, first_partition_input, + attributes); + case kTfLiteInt64: + return GetFirstDimensionSizeT<int64_t>(context, first_partition_input, + attributes); + default: + context->ReportError(context, + "Not supported row partitioning tensor type"); + return -1; + } +} + +bool ValidateDefaultValueShape(TfLiteContext* context, + const RuntimeShape& default_value_shape, + const RuntimeShape& /*value_shape*/) { + // TF implementation also checks that shapes are not defined, not needed in + // TFLite. + // TODO(mgubin): Only scalar default value sizes are supported. + if (default_value_shape.FlatSize() != 1) { + context->ReportError(context, "Only scalar default value is supported"); + return false; + } + return true; +} + +RuntimeShape TensorShapeFromTensor(const TfLiteTensor& tensor) { + // TODO(mgubin): No checks, see + // third_party/tensorflow/core/kernels/list_kernels.cc + const RuntimeShape tensor_shape(tensor.dims->size, tensor.dims->data); + if (0 == tensor.dims->size) { + // If the input tensor is scalar then the shape is empty (also scalar). + return RuntimeShape{}; + } + RuntimeShape result(tensor_shape.FlatSize()); + switch (tensor.type) { + case kTfLiteInt32: { + for (int i = 0; i < tensor_shape.FlatSize(); ++i) { + result.SetDim(i, GetTensorData<int32_t>(&tensor)[i]); + } + } break; + case kTfLiteInt64: { + for (int i = 0; i < tensor_shape.FlatSize(); ++i) { + result.SetDim(i, GetTensorData<int64_t>(&tensor)[i]); + } + } break; + default: { + // Checked in Prepare. 
+ } + } + return result; +} + +const TfLiteTensor* GetRowPartitionTensor( + const ConversionAttributes& conversion_attributes, TfLiteContext* context, + TfLiteNode* node, int dimension) { + if (conversion_attributes.partition_types.front() == + tensorflow::RowPartitionType::FIRST_DIM_SIZE) { + return &context->tensors[node->inputs->data[kFirstPartitionInputIndex + 1 + + dimension]]; + } else { + return &context->tensors[node->inputs + ->data[kFirstPartitionInputIndex + dimension]]; + } +} + +int GetMaxWidthValueRowID(const TfLiteTensor* tensor) { + const RuntimeShape tensor_shape(tensor->dims->size, tensor->dims->data); + const int index_length = tensor_shape.FlatSize(); + if (index_length == 0) { + return 0; + } + auto value_rowids = [tensor](int index) { + switch (tensor->type) { + case kTfLiteInt32: + return static_cast<int>(tensor->data.i32[index]); + case kTfLiteInt64: + return static_cast<int>(tensor->data.i64[index]); + default: + // TODO(mgubin): Add error checks. + return 0; + } + }; + int first_equal_index = 0; + int first_equal_index_value = value_rowids(0); + int max_width = 0; + for (int i = 0; i < index_length; ++i) { + const int value = value_rowids(i); + if (value != first_equal_index_value) { + first_equal_index_value = value; + max_width = std::max(i - first_equal_index, max_width); + first_equal_index = i; + } + } + return std::max(index_length - first_equal_index, max_width); +} + +int GetMaxWidthRowSplit(const TfLiteTensor* tensor) { + const RuntimeShape tensor_shape(tensor->dims->size, tensor->dims->data); + const int tensor_length = tensor_shape.FlatSize(); + if (tensor_length == 0 || tensor_length == 1) { + return 0; + } + auto value_rowsplit = [tensor](int index) { + switch (tensor->type) { + case kTfLiteInt32: + return static_cast<int>(tensor->data.i32[index]); + case kTfLiteInt64: + return static_cast<int>(tensor->data.i64[index]); + default: + // TODO(mgubin): Add error checks. 
+ return 0; + } + }; + int max_width = 1; + int prev_split = value_rowsplit(0); + for (int i = 1; i < tensor_length; ++i) { + const int split = value_rowsplit(i); + max_width = std::max(max_width, split - prev_split); + prev_split = split; + } + return max_width; +} + +int GetMaxWidth(const ConversionAttributes& conversion_attributes, + TfLiteContext* context, TfLiteNode* node, int dimension) { + const TfLiteTensor* tensor = GetRowPartitionTensor( + conversion_attributes, context, node, dimension - 1); + switch (conversion_attributes.GetRowPartitionTypeByDimension(dimension - 1)) { + case tensorflow::RowPartitionType::VALUE_ROWIDS: + return GetMaxWidthValueRowID(tensor); + case tensorflow::RowPartitionType::ROW_SPLITS: + return GetMaxWidthRowSplit(tensor); + default: + context->ReportError(context, "Cannot handle partition type"); + return -1; + } +} + +RuntimeShape CombineRaggedTensorToTensorShapes( + int ragged_rank, const RuntimeShape& output_shape, + const RuntimeShape& value_shape) { + // TODO(mgubin): No checks, see + // third_party/tensorflow/core/ops/ragged_to_dense_util.cc + RuntimeShape result(output_shape); + if (output_shape.DimensionsCount() == 0) { + const int output_shape_rank = ragged_rank + value_shape.DimensionsCount(); + result.Resize(output_shape_rank); + for (int i = 0; i < output_shape_rank; ++i) { + result.SetDim(i, -1); + } + } + const int need_to_set = + output_shape.DimensionsCount() - value_shape.DimensionsCount(); + for (int i = 1; i < value_shape.DimensionsCount(); ++i) { + result.SetDim(need_to_set + i, value_shape.Dims(i)); + } + return result; +} + +RuntimeShape CalculateOutputSize( + const ConversionAttributes& conversion_attributes, TfLiteContext* context, + TfLiteNode* node, int first_dimension, int ragged_rank, + const TfLiteTensor& values, const TfLiteTensor& default_value, + const TfLiteTensor& output_shape) { + RuntimeShape values_shape(values.dims->size, values.dims->data); + RuntimeShape 
default_value_shape(default_value.dims->size, + default_value.dims->data); + + if (!ValidateDefaultValueShape(context, default_value_shape, values_shape)) { + return {}; + } + RuntimeShape output_shape_shape = TensorShapeFromTensor(output_shape); + + RuntimeShape result_shape = CombineRaggedTensorToTensorShapes( + ragged_rank, output_shape_shape, values_shape); + if (result_shape.Dims(0) < 0) { + result_shape.SetDim(0, first_dimension); + } + for (int i = 1; i <= ragged_rank; ++i) { + if (result_shape.Dims(i) < 0) { + result_shape.SetDim(i, + GetMaxWidth(conversion_attributes, context, node, i)); + } + } + return result_shape; +} + +TfLiteIntArray* IntArrayFromShape(const RuntimeShape& shape) { + TfLiteIntArray* result = TfLiteIntArrayCreate(shape.DimensionsCount()); + for (int i = 0; i < shape.DimensionsCount(); ++i) { + result->data[i] = shape.Dims(i); + } + return result; +} + +/** + * The output_index represents the index in the output tensor + * where the first element of a particular dimension would be written. + * If it is -1, it indicates that the index is out of scope. + * Example, given first_dimension = 10, first_dimension_output = 6, + * and output_index_multiplier = 100: + * result = [0 100 200 300 400 500 -1 -1 -1 -1] + * If first_dimension_output = 11 instead, then: + * result = [0 100 200 300 400 500 600 700 800 900] + */ +void CalculateFirstParentOutputIndex(int first_dimension, + int output_index_multiplier, + int first_dimension_output, + std::vector<int>* result) { + const int min_dimension = std::min(first_dimension, first_dimension_output); + result->reserve(first_dimension); + int current_output_index = 0; + for (int i = 0; i < min_dimension; + ++i, current_output_index += output_index_multiplier) { + result->push_back(current_output_index); + } + for (int i = min_dimension; i < first_dimension; ++i) { + result->push_back(-1); + } +} +// Calculate the output index of the first element of a list. 
+// The parent_output_index is the same computation for the previous list. +// -1 indicates an element or list that is out of range. +// The output_index_multiplier is the number of output indices one moves +// forward for each column. +// E.g., given: +// value_rowids:[0 1 2 2 2 3 5 5 6] +// parent_output_index:[1000 1100 2000 2100 -1 3000 4000] +// output_index_multiplier: 10 +// output_size: 2 +// You get: +// result = [1000 1100 2000 2010 -1 2100 -1 -1 3000] +// result[0] = parent_output_index[value_rowids[0]] +// result[1] = parent_output_index[value_rowids[1]] +// result[2] = parent_output_index[value_rowids[2]] +// result[3] = parent_output_index[value_rowids[2] + 10] +// result[4] = -1 because it is the third element the size is 2. +// result[5] = parent_output_index[value_rowids[3]] +// result[6] = -1 because parent_output_index[value_rowids[6]] == -1 +// result[7] = -1 because parent_output_index[value_rowids[6]] == -1 +// result[8] = parent_output_index[value_rowids[7]] +void CalculateOutputIndexValueRowID(const TfLiteTensor& value_rowids, + const std::vector<int>& parent_output_index, + int output_index_multiplier, + int output_size, std::vector<int>* result) { + const RuntimeShape tensor_shape(value_rowids.dims->size, + value_rowids.dims->data); + const int index_size = tensor_shape.FlatSize(); + result->reserve(index_size); + if (index_size == 0) { + return; + } + + auto value_rowids_val = [value_rowids](int index) { + switch (value_rowids.type) { + case kTfLiteInt32: + return static_cast<int>(value_rowids.data.i32[index]); + case kTfLiteInt64: + return static_cast<int>(value_rowids.data.i64[index]); + default: + // TODO(mgubin): Add error checks. 
+ return 0; + } + }; + int current_output_column = 0; + int current_value_rowid = value_rowids_val(0); + // DCHECK_LT(current_value_rowid, parent_output_index.size()); + int current_output_index = parent_output_index[current_value_rowid]; + result->push_back(current_output_index); + for (int i = 1; i < index_size; ++i) { + int next_value_rowid = value_rowids_val(i); + if (next_value_rowid == current_value_rowid) { + if (current_output_index >= 0) { + ++current_output_column; + if (current_output_column < output_size) { + current_output_index += output_index_multiplier; + } else { + current_output_index = -1; + } + } + } else { + current_output_column = 0; + current_value_rowid = next_value_rowid; + // DCHECK_LT(next_value_rowid, parent_output_index.size()); + current_output_index = parent_output_index[next_value_rowid]; + } + result->push_back(current_output_index); + } + // DCHECK_EQ(result->size(), value_rowids.size()); +} + +void CalculateOutputIndexRowSplit(const TfLiteTensor& row_split, + const std::vector<int>& parent_output_index, + int output_index_multiplier, int output_size, + std::vector<int>* result) { + const RuntimeShape row_split_shape(row_split.dims->size, + row_split.dims->data); + const int row_split_size = row_split_shape.FlatSize(); + auto row_split_val = [row_split](int index) { + switch (row_split.type) { + case kTfLiteInt32: + return static_cast<int>(row_split.data.i32[index]); + case kTfLiteInt64: + return static_cast<int>(row_split.data.i64[index]); + default: + // TODO(mgubin): Add error checks. 
+ return 0; + } + }; + if (row_split_size > 0) { + result->reserve(row_split_val(row_split_size - 1)); + } + for (int i = 0; i < row_split_size - 1; ++i) { + const int row_length = row_split_val(i + 1) - row_split_val(i); + int real_length = std::min(output_size, row_length); + int parent_output_index_current = parent_output_index[i]; + + if (parent_output_index_current == -1) { + real_length = 0; + } + for (int j = 0; j < real_length; ++j) { + result->push_back(parent_output_index_current); + parent_output_index_current += output_index_multiplier; + } + for (int j = 0; j < row_length - real_length; ++j) { + result->push_back(-1); + } + } + // if (row_split_size > 0) { + // DCHECK_EQ(result->size(), row_split(row_split_size - 1)); + //} +} + +TfLiteStatus CalculateOutputIndex( + const ConversionAttributes& conversion_attributes, TfLiteContext* context, + TfLiteNode* node, int dimension, + const std::vector<int>& parent_output_index, int output_index_multiplier, + int output_size, std::vector<int>* result) { + const TfLiteTensor* row_partition_tensor = + GetRowPartitionTensor(conversion_attributes, context, node, dimension); + auto partition_type = + conversion_attributes.GetRowPartitionTypeByDimension(dimension); + switch (partition_type) { + case tensorflow::RowPartitionType::VALUE_ROWIDS: + CalculateOutputIndexValueRowID(*row_partition_tensor, parent_output_index, + output_index_multiplier, output_size, + result); + return kTfLiteOk; + case tensorflow::RowPartitionType::ROW_SPLITS: + CalculateOutputIndexRowSplit(*row_partition_tensor, parent_output_index, + output_index_multiplier, output_size, + result); + return kTfLiteOk; + default: + context->ReportError(context, "Unsupported partition type"); + return kTfLiteError; + } +} + +template <typename VALUE_TYPE> +void SetOutputT(TfLiteContext* context, int ragged_rank, + const std::vector<int>& output_index, + const TfLiteTensor& values_tensor, + const TfLiteTensor& default_value_tensor, + TfLiteTensor* 
output_tensor) { + const VALUE_TYPE* values_base = GetTensorData<VALUE_TYPE>(&values_tensor); + VALUE_TYPE* output_base = GetTensorData<VALUE_TYPE>(output_tensor); + const VALUE_TYPE* default_value = + GetTensorData<VALUE_TYPE>(&default_value_tensor); + + RuntimeShape output_shape = GetTensorShape(output_tensor); + RuntimeShape element_shape = + RuntimeShape(output_shape.DimensionsCount() - ragged_rank - 1, + output_shape.DimsData() + ragged_rank + 1); + + // element_shape.RemoveDimRange(0, ragged_rank + 1); + const int value_element_size = element_shape.FlatSize(); + size_t output_index_size = output_index.size(); + + // Loop through the output_index vector, finding contiguous regions that + // should be copied. Once we find the end of a contiguous region, copy it + // and add any necessary padding (with default_value). + int src_start = 0; // Start of contiguous region (in values) + int dst_start = 0; // Destination for contiguous region (in output) + int dst_end = 0; // Destination for contiguous region (in output) + for (int src_i = 0; src_i <= output_index_size; ++src_i) { + // dst_i is the destination where the value at src_i should be copied. + int dst_i = src_i < output_index_size ? output_index[src_i] : -1; + + // If we're still in a contiguous region, then update dst_end go to the + // next src_i. + if (dst_i == dst_end) { + ++dst_end; + continue; + } + + // We found the end of contiguous region. This can be because we found + // a gap (dst_i > dst_end), or a source value that shouldn't be copied + // because it's out-of-bounds (dst_i == -1), or the end of the tensor + // (dst_i = -1). + if (dst_start < dst_end) { + // Copy the contiguous region. 
+ const VALUE_TYPE* src = values_base + src_start * value_element_size; + VALUE_TYPE* dst = output_base + dst_start * value_element_size; + int nvals = (dst_end - dst_start) * value_element_size; + std::copy(src, src + nvals, dst); + // copy_array<VALUE_TYPE, int>(dst, src, nvals); + } + + // Add any necessary padding (w/ default_value). + if (src_i >= output_index_size) { + // We reached the end of values: pad to the end of output. + const int output_size = output_shape.FlatSize(); + dst_i = output_size / value_element_size; + } + if (dst_i > dst_end) { + std::fill(output_base + dst_end * value_element_size, + output_base + dst_i * value_element_size, *default_value); + dst_end = dst_i; + } + + // Update indices. + if (dst_i < 0) { + // src_i should be skipped -- leave it out of the contiguous region. + src_start = src_i + 1; + dst_start = dst_end; + } else { + // src_i should be copied -- include it in the contiguous region. + src_start = src_i; + dst_start = dst_end; + dst_end = dst_start + 1; + } + } +} + +void SetOutput(TfLiteContext* context, int ragged_rank, + const std::vector<int>& output_index, + const TfLiteTensor& values_tensor, + const TfLiteTensor& default_value_tensor, + TfLiteTensor* output_tensor) { + switch (output_tensor->type) { + case kTfLiteInt32: + SetOutputT<int32_t>(context, ragged_rank, output_index, values_tensor, + default_value_tensor, output_tensor); + break; + case kTfLiteInt64: + SetOutputT<int64_t>(context, ragged_rank, output_index, values_tensor, + default_value_tensor, output_tensor); + break; + case kTfLiteFloat32: + SetOutputT<float>(context, ragged_rank, output_index, values_tensor, + default_value_tensor, output_tensor); + break; + default: + context->ReportError(context, "Not supported values type"); + } +} + +} // namespace + +void* Initialize(TfLiteContext* context, const char* buffer, size_t length) { + auto attributes = std::make_unique<ConversionAttributes>(); + + const uint8_t* buffer_t = reinterpret_cast<const 
uint8_t*>(buffer); + + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + // TODO (mgubin): Converting flat buffer to a vector of strings looks not very + // effective but simple. A cleaner way is needed. + const flexbuffers::TypedVector row_partition_types_attr = + m[kRowPartitionTypesAttr].AsTypedVector(); + std::vector<std::string> row_partition_types_attr_strings; + row_partition_types_attr_strings.reserve(row_partition_types_attr.size()); + for (int i = 0; i < row_partition_types_attr.size(); ++i) { + row_partition_types_attr_strings.emplace_back( + row_partition_types_attr[i].AsString().str()); + } + attributes->partition_types = + tensorflow::GetRowPartitionTypesHelper(row_partition_types_attr_strings); + if (attributes->partition_types.size() != + row_partition_types_attr_strings.size()) { + context->ReportError(context, "Can't parse partition type attribute"); + return nullptr; + } + attributes->ragged_rank = + tensorflow::GetRaggedRank(attributes->partition_types); + return attributes.release(); +} +void Free(TfLiteContext* /*context*/, void* buffer) { + ConversionAttributes* attributes = + reinterpret_cast<ConversionAttributes*>(buffer); + delete attributes; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const ConversionAttributes* attributes = + reinterpret_cast<ConversionAttributes*>(node->user_data); + if (attributes == nullptr) { + // Parsing attributes failed, can't prepare. + context->ReportError(context, "Attributes are not initialized"); + return kTfLiteError; + } + // The output tensor need to be set to dynamic because it can have different + // size. 
+ TfLiteTensor& output_tensor = + context->tensors[node->outputs->data[kOutputTensor]]; + SetTensorToDynamic(&output_tensor); + + // Check that input shape tensor is int32 or int64 + TfLiteTensor& input_shape = context->tensors[node->inputs->data[kShapeInput]]; + if (input_shape.type != kTfLiteInt32 && input_shape.type != kTfLiteInt64) { + context->ReportError(context, + "Input form tensor could be only int32 or int64"); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const ConversionAttributes* attributes = + reinterpret_cast<ConversionAttributes*>(node->user_data); + TfLiteTensor& input_shape = context->tensors[node->inputs->data[kShapeInput]]; + TfLiteTensor& input_values = + context->tensors[node->inputs->data[kValuesInput]]; + TfLiteTensor& default_value = + context->tensors[node->inputs->data[kDefaultValueInput]]; + // TODO (mgubin): Only scallar default value is supported. + if (RuntimeShape(default_value.dims->size, default_value.dims->data) + .FlatSize() != 1) { + context->ReportError(context, "Only scallar default value is supported"); + return kTfLiteError; + } + TfLiteTensor& first_partition_input = + context->tensors[node->inputs->data[kFirstPartitionInputIndex]]; + + // Calculate dimensions. + const int first_dimension = + GetFirstDimensionSize(context, first_partition_input, attributes); + if (first_dimension < 0) { + return kTfLiteError; + } + RuntimeShape output_shape = CalculateOutputSize( + *attributes, context, node, first_dimension, attributes->ragged_rank, + input_values, default_value, input_shape); + if (output_shape.DimensionsCount() == 0) { + return kTfLiteError; + } + + std::vector<int> multiplier; + multiplier.resize(attributes->ragged_rank + 1); + multiplier.back() = 1; + for (int i = multiplier.size() - 2; i >= 0; --i) { + multiplier[i] = multiplier[i + 1] * output_shape.Dims(i + 1); + } + + // Allocate output tensor. 
+ TfLiteTensor& output_tensor = + context->tensors[node->outputs->data[kOutputTensor]]; + + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, &output_tensor, + IntArrayFromShape(output_shape))); + + // Copy data. + const int full_size = multiplier.front() * output_shape.Dims(0); + if (full_size > 0) { + std::vector<int> output_index, new_output_index; + int nvals = input_values.dims->data[0]; + output_index.reserve(nvals); + new_output_index.reserve(nvals); + + CalculateFirstParentOutputIndex(first_dimension, multiplier[0], + output_shape.Dims(0), &output_index); + for (int i = 1; i <= attributes->ragged_rank; ++i) { + TF_LITE_ENSURE_OK( + context, CalculateOutputIndex( + *attributes, context, node, i - 1, output_index, + multiplier[i], output_shape.Dims(i), &new_output_index)); + output_index.swap(new_output_index); + new_output_index.clear(); + } + + SetOutput(context, attributes->ragged_rank, output_index, input_values, + default_value, &output_tensor); + } + return kTfLiteOk; +} + +} // namespace ragged_tensor_to_tensor +} // namespace ragged + +TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR() { + static TfLiteRegistration r = {ragged::ragged_tensor_to_tensor::Initialize, + ragged::ragged_tensor_to_tensor::Free, + ragged::ragged_tensor_to_tensor::Prepare, + ragged::ragged_tensor_to_tensor::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc new file mode 100644 index 00000000..b1cde57c --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/ragged/ragged_tensor_to_tensor_tflite_test.cc @@ -0,0 +1,283 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <initializer_list> +#include <string> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { +namespace ops { +namespace custom { +TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR(); +} // namespace custom +} // namespace ops + +namespace { + +class RaggedTensorToTensorOpModel : public SingleOpModel { + public: + RaggedTensorToTensorOpModel(int output_shape_dims, + std::initializer_list<int> values_shape, + std::initializer_list<std::initializer_list<int>> + partition_tensors_shapes, + std::vector<std::string> partition_types, + TensorType value_type = TensorType_FLOAT32, + TensorType index_type = TensorType_INT32) { + // A structure to collect shapes for the input. 
+ std::vector<std::vector<int>> shapes; + input_shape_ = AddInput(index_type); + shapes.push_back({output_shape_dims}); + input_values_ = AddInput(value_type); + shapes.emplace_back(values_shape); + input_default_values_ = AddInput(value_type); + shapes.push_back({1}); + for (const auto& p : partition_tensors_shapes) { + partition_tensors_.push_back(AddInput(TensorType_INT32)); + shapes.emplace_back(p); + } + output_ = AddOutput(value_type); + + flexbuffers::Builder fbb; + size_t start = fbb.StartMap(); + { + size_t start = fbb.StartVector("row_partition_types"); + for (const auto& s : partition_types) { + fbb.String(s); + } + fbb.EndVector(start, /*typed=*/true, /*fixed=*/false); + } + fbb.Int("num_row_partition_tensors", partition_types.size()); + fbb.EndMap(start); + fbb.Finish(); + SetCustomOp("RaggedTensorToTensor", fbb.GetBuffer(), + ops::custom::Register_RAGGED_TENSOR_TO_TENSOR); + BuildInterpreter(shapes); + } + + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); } + std::vector<int32> GetOutputInt() { return ExtractVector<int32>(output_); } + + void InvokeFloat(const std::vector<int>& shape, + const std::vector<float>& values, float default_value, + const std::vector<std::vector<int>>& partition_values) { + PopulateTensor(input_shape_, shape); + PopulateTensor(input_values_, values); + PopulateTensor(input_default_values_, {default_value}); + for (int i = 0; i < partition_values.size(); ++i) { + PopulateTensor(partition_tensors_[i], partition_values[i]); + } + SingleOpModel::Invoke(); + } + void InvokeInt(const std::vector<int>& shape, + const std::vector<int32>& values, int32 default_value, + const std::vector<std::vector<int>>& partition_values) { + PopulateTensor(input_shape_, shape); + PopulateTensor(input_values_, values); + PopulateTensor(input_default_values_, {default_value}); + for (int i = 0; i < partition_values.size(); ++i) { + 
PopulateTensor(partition_tensors_[i], partition_values[i]); + } + SingleOpModel::Invoke(); + } + + private: + int input_shape_; + int input_values_; + int input_default_values_; + std::vector<int> partition_tensors_; + int output_; +}; + +TEST(RaggedTensorToTensorTest, RaggedTensorToTensor) { + // indices = [2, 1, 0, 3] + // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] + // params.shape = [4, None] + RaggedTensorToTensorOpModel model( + 2, // output_shape_dims + {9}, // values_shape + {{1}, {9}}, // partition_tensors_shapes + std::vector<std::string>({"FIRST_DIM_SIZE", "VALUE_ROWIDS"})); + model.InvokeFloat({4, 4}, // shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values + 1.5, // default_value + std::vector<std::vector<int>>( + {std::vector<int>({4}), + std::vector<int>({0, 0, 0, 2, 2, 2, 2, 3, 3})})); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({4, 4})); + EXPECT_THAT(model.GetOutputFloat(), + testing::ElementsAreArray({.1, .2, .3, 1.5, 1.5, 1.5, 1.5, 1.5, + .4, .5, .6, .7, .8, .9, 1.5, 1.5})); +} + +TEST(RaggedTensorToTensorTest, RaggedTensorToTensorRowSplits) { + // indices = [2, 1, 0, 3] + // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] + RaggedTensorToTensorOpModel model(2, // output_shape_dims + {9}, // values_shape + {{5}}, // partition_tensors_shapes + std::vector<std::string>({"ROW_SPLITS"})); + model.InvokeFloat( + {4, 4}, // shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values + 1.5, // default_value + std::vector<std::vector<int>>({std::vector<int>({0, 3, 3, 7, 9})})); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({4, 4})); + EXPECT_THAT(model.GetOutputFloat(), + testing::ElementsAreArray({.1, .2, .3, 1.5, 1.5, 1.5, 1.5, 1.5, + .4, .5, .6, .7, .8, .9, 1.5, 1.5})); +} + +TEST(RaggedTensorToTensorTest, RaggedTensorToTensor_3DParams) { + // params = [ + // [[]], + // [[.1, .2], [.3]], + // [], + // [[.4, .5], [.6, .7, .8]], + // [[.9]] + // ] + RaggedTensorToTensorOpModel model( + 3, // 
output_shape_dims + {9}, // values_shape + {{1}, {6}, {9}}, // partition_tensors_shapes + std::vector<std::string>( + {"FIRST_DIM_SIZE", "VALUE_ROWIDS", "VALUE_ROWIDS"})); + model.InvokeFloat( + {5, 2, 3}, // shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values + 1.5, // default_value + std::vector<std::vector<int>>( + {std::vector<int>({5}), std::vector<int>({0, 1, 1, 3, 3, 4}), + std::vector<int>({1, 1, 2, 3, 3, 4, 4, 4, 5})})); + + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({5, 2, 3})); + EXPECT_THAT(model.GetOutputFloat(), + testing::ElementsAreArray({1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .1, .2, + 1.5, .3, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, + 1.5, 1.5, .4, .5, 1.5, .6, .7, .8, + .9, 1.5, 1.5, 1.5, 1.5, 1.5})); +} + +TEST(RaggedTensorToTensorOpTest, RaggedTensorToTensor_3DParamsRowSplits) { + // params = [ + // [[]], + // [[.1, .2], [.3]], + // [], + // [[.4, .5], [.6, .7, .8]], + // [[.9]] + // ] + RaggedTensorToTensorOpModel model( + 3, // output_shape_dims + {9}, // values_shape + {{6}, {7}}, // partition_tensors_shapes + std::vector<std::string>({"ROW_SPLITS", "ROW_SPLITS"})); + model.InvokeFloat( + {5, 2, 3}, // shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values + 1.5, // default_value + std::vector<std::vector<int>>({std::vector<int>({0, 1, 3, 3, 5, 6}), + std::vector<int>({0, 0, 2, 3, 5, 8, 9})})); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({5, 2, 3})); + EXPECT_THAT(model.GetOutputFloat(), + testing::ElementsAreArray({1.5, 1.5, 1.5, 1.5, 1.5, 1.5, .1, .2, + 1.5, .3, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, + 1.5, 1.5, .4, .5, 1.5, .6, .7, .8, + .9, 1.5, 1.5, 1.5, 1.5, 1.5})); +} + +TEST(RaggedTensorToTensorTest, RaggedTensorToTensor_3DParamsRowSplits2) { + // params = [ + // [[0, 1, 2], []], + // [], + // [[3]] + // ] + + RaggedTensorToTensorOpModel model( + 3, // output_shape_dims + {4}, // values_shape + {{4}, {4}}, // partition_tensors_shapes + std::vector<std::string>({"ROW_SPLITS", "ROW_SPLITS"}), TensorType_INT32); + 
model.InvokeInt( + {3, 2, 3}, // shape + {0, 1, 2, 3}, // values + 5, // default_value + std::vector<std::vector<int>>( + {std::vector<int>({0, 2, 2, 3}), std::vector<int>({0, 3, 3, 4})})); + + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 2, 3})); + + EXPECT_THAT(model.GetOutputInt(), + testing::ElementsAreArray( + {0, 1, 2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 5, 5, 5, 5, 5})); +} + +TEST(RaggedTensorToTensorTest, RaggedTensorToTensorContractExpanded) { + // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] + RaggedTensorToTensorOpModel model( + 2, // output_shape_dims + {9}, // values_shape + {{1}, {9}}, // partition_tensors_shapes + std::vector<std::string>({"FIRST_DIM_SIZE", "VALUE_ROWIDS"})); + model.InvokeFloat({3, 5}, // shape + {.1, .2, .3, .4, .5, .6, .7, .8, .9}, // values + 1.5, // default_value + std::vector<std::vector<int>>( + {std::vector<int>({4}), + std::vector<int>({0, 0, 0, 2, 2, 2, 2, 3, 3})})); + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 5})); + + EXPECT_THAT(model.GetOutputFloat(), + testing::ElementsAreArray({.1, .2, .3, 1.5, 1.5, // + 1.5, 1.5, 1.5, 1.5, 1.5, // + .4, .5, .6, .7, 1.5})); +} + +// Adds a dense dimension. 
+TEST(RaggedTensorToTensorTest, RaggedTensorToTensorContractExpandedDense) { + // params = [[.1, .2, .3], [], [.4, .5, .6, .7], [.8, .9]] + RaggedTensorToTensorOpModel model( + 3, // output_shape_dims + {9, 2}, // values_shape + {{1}, {9}}, // partition_tensors_shapes + std::vector<std::string>({"FIRST_DIM_SIZE", "VALUE_ROWIDS"})); + + model.InvokeFloat({3, 5, 2}, // shape + {.1, 1.1, .2, 1.2, .3, 1.3, .4, 1.4, .5, 1.5, .6, 1.6, .7, + 1.7, .8, 1.8, .9, 1.9}, // values + 1.5, // default_value + std::vector<std::vector<int>>( + {std::vector<int>({4}), + std::vector<int>({0, 0, 0, 2, 2, 2, 2, 3, 3})})); + + EXPECT_THAT(model.GetOutputShape(), testing::ElementsAreArray({3, 5, 2})); + EXPECT_THAT(model.GetOutputFloat(), + testing::ElementsAreArray( + {.1, 1.1, .2, 1.2, .3, 1.3, 1.5, 1.5, 1.5, 1.5, // + 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, // + .4, 1.4, .5, 1.5, .6, 1.6, .7, 1.7, 1.5, 1.5})); +} +} // namespace +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD new file mode 100644 index 00000000..e8df50f3 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/BUILD @@ -0,0 +1,389 @@ +# Memorymappable, WASM compilable, implementation of the encoder. 
+# + +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") +load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load(":native.bzl", "micore_tf_copts", "micore_tf_deps") + +package( + default_visibility = [ + "//tensorflow_lite_support:users", + ], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "testdata", + srcs = glob([ + "testdata/**", + ]), +) + +filegroup( + name = "config_fbs", + srcs = ["config.fbs"], +) + +flatbuffer_cc_library( + name = "config", + srcs = [ + "config.fbs", + ], +) + +flatbuffer_cc_library( + name = "encoder_config", + srcs = [ + "encoder_config.fbs", + ], + includes = [":config_fbs"], +) + +flatbuffer_cc_library( + name = "decoder_config", + srcs = [ + "decoder_config.fbs", + ], + includes = [":config_fbs"], +) + +cc_library( + name = "utils", + srcs = [ + ], + hdrs = [ + "utils.h", + ], +) + +cc_library( + name = "double_array_trie", + srcs = [ + ], + hdrs = [ + "double_array_trie.h", + ], + deps = [ + ":config", + ":utils", + ], +) + +cc_library( + name = "double_array_trie_builder", + srcs = [ + "double_array_trie_builder.cc", + ], + hdrs = [ + "double_array_trie_builder.h", + ], + deps = [ + ":config", + ":utils", + "@darts_clone", + ], +) + +cc_test( + name = "double_array_trie_test", + srcs = [ + "double_array_trie_test.cc", + ], + deps = [ + ":double_array_trie", + ":double_array_trie_builder", + ":encoder_config", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "sentencepiece_constants", + srcs = [], + hdrs = ["sentencepiece_constants.h"], +) + +cc_library( + name = "model_converter", + srcs = [ + "model_converter.cc", + ], + hdrs = [ + "model_converter.h", + ], + deps = [ + ":config", + ":decoder_config", + ":double_array_trie_builder", + ":encoder_config", + ":sentencepiece_constants", + "//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + 
"@com_google_sentencepiece//src:sentencepiece_model_cc_proto", + ], +) + +cc_library( + name = "optimized_encoder", + srcs = [ + "optimized_encoder.cc", + ], + hdrs = [ + "optimized_encoder.h", + ], + deps = [ + ":config", + ":double_array_trie", + ":encoder_config", + ], +) + +cc_library( + name = "optimized_decoder", + srcs = [ + "optimized_decoder.cc", + ], + hdrs = [ + "optimized_decoder.h", + ], + deps = [ + "config", + ":decoder_config", + ":double_array_trie", + ], +) + +cc_library( + name = "sentencepiece_tokenizer_h", + hdrs = [ + "sentencepiece_tokenizer.h", + ], +) + +cc_library( + name = "sentencepiece_detokenizer_h", + hdrs = [ + "sentencepiece_detokenizer.h", + ], +) + +cc_library( + name = "sentencepiece_tokenizer_op", + srcs = ["sentencepiece_tokenizer_op.cc"], + copts = micore_tf_copts(), + visibility = [ + "//visibility:public", + ], + deps = [ + ":sentencepiece_tokenizer_h", + ":optimized_encoder", + ] + micore_tf_deps(), + alwayslink = 1, +) + +cc_binary( + name = "sentencepiece_tokenizer_op.so", + srcs = [ + "sentencepiece_tokenizer_op.cc", + ], + copts = micore_tf_copts(), + linkshared = 1, + deps = [ + ":sentencepiece_tokenizer_h", + ":optimized_encoder", + ] + micore_tf_deps(), +) + +cc_library( + name = "sentencepiece_detokenizer_op", + srcs = ["sentencepiece_detokenizer_op.cc"], + copts = micore_tf_copts(), + visibility = [ + "//visibility:public", + ], + deps = [ + ":sentencepiece_detokenizer_h", + ":optimized_decoder", + ] + micore_tf_deps(), + alwayslink = 1, +) + +cc_binary( + name = "sentencepiece_detokenizer_op.so", + srcs = [ + "sentencepiece_detokenizer_op.cc", + ], + copts = micore_tf_copts(), + linkshared = 1, + deps = [ + ":sentencepiece_detokenizer_h", + ":optimized_decoder", + ] + micore_tf_deps(), +) + +cc_library( + name = "sentencepiece_tokenizer_tflite", + srcs = ["sentencepiece_tokenizer_tflite.cc"], + visibility = [ + "//visibility:public", + ], + deps = + [ + ":optimized_encoder", + ":sentencepiece_tokenizer_h", + 
"@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_library( + name = "sentencepiece_detokenizer_tflite", + srcs = ["sentencepiece_detokenizer_tflite.cc"], + visibility = [ + "//visibility:public", + ], + deps = + [ + ":optimized_decoder", + ":sentencepiece_detokenizer_h", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_test( + name = "optimized_encoder_test", + srcs = [ + "optimized_encoder_test.cc", + ], + data = [ + ":testdata", + ], + deps = [ + ":double_array_trie_builder", + ":encoder_config", + ":model_converter", + ":optimized_encoder", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + "@com_google_sentencepiece//src:sentencepiece_cc_proto", + "@com_google_sentencepiece//src:sentencepiece_processor", + "@org_tensorflow//tensorflow/core:lib", + ], +) + +cc_test( + name = "optimized_decoder_test", + srcs = [ + "optimized_decoder_test.cc", + ], + data = [ + ":testdata", + ], + deps = [ + ":decoder_config", + ":double_array_trie_builder", + ":model_converter", + ":optimized_decoder", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + "@com_google_sentencepiece//src:sentencepiece_cc_proto", + "@com_google_sentencepiece//src:sentencepiece_processor", + "@org_tensorflow//tensorflow/core:lib", + ], +) + +cc_library( + name = "py_tflite_registerer", + srcs = ["py_tflite_registerer.cc"], + hdrs = ["py_tflite_registerer.h"], + deps = [ + 
":sentencepiece_detokenizer_tflite", + ":sentencepiece_tokenizer_tflite", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +config_setting( + name = "armeabi_v7a_and_fastbuild", + values = { + "cpu": "armeabi-v7a", + "compilation_mode": "fastbuild", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "armeabi_v7a_and_dbg", + values = { + "cpu": "armeabi-v7a", + "compilation_mode": "dbg", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "android", + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "macos_i386", + values = { + "apple_platform_type": "macos", + "cpu": "darwin", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "macos_x86_64", + values = { + "apple_platform_type": "macos", + "cpu": "darwin_x86_64", + }, + visibility = ["//visibility:public"], +) + +alias( + name = "macos", + actual = select({ + ":macos_i386": ":macos_i386", + ":macos_x86_64": ":macos_x86_64", + "//conditions:default": ":macos_i386", # Arbitrarily chosen from above. + }), + visibility = ["//visibility:public"], +) + +config_setting( + name = "ios", + values = { + "crosstool_top": "@bazel_tools//tools/cpp:toolchain", + "apple_platform_type": "ios", + }, + visibility = ["//visibility:public"], +) + +alias( + name = "apple", + actual = select({ + ":macos": ":macos", + ":ios": ":ios", + "//conditions:default": ":ios", # Arbitrarily chosen from above. + }), + visibility = ["//visibility:public"], +) diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/config.fbs b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/config.fbs new file mode 100644 index 00000000..eba0bd8a --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/config.fbs @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +namespace tflite.ops.custom.sentencepiece; + +table Trie { + nodes: [uint32]; +} + + +enum EncoderVersion: byte { + SENTENCE_PIECE = 0, +} diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config.fbs b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config.fbs new file mode 100644 index 00000000..4a230ed9 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config.fbs @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +include "config.fbs"; + +namespace tflite.ops.custom.sentencepiece; + + +table DecoderConfig { + version: EncoderVersion = SENTENCE_PIECE; + + // The offset for encoding, usually used when codes with low codes are reserved + // for some special needs. + encoding_offset: int32; + + // A vector of strings that represent sentencepieces. + decode_pieces: [string]; + + // TODO(mgubin): Currently is not populated, haven't seen any Sentencepiece + // model with a denormalizer. + denormalized_prefixes: Trie; + denormalized_replacements: [byte]; + + // During encoding a dummy prefix (a whitespace) can be added to the input string, + // if this flag is true, this prefix will be removed. + remove_dummy_prefix: bool; + +} + + +root_type DecoderConfig; diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h new file mode 100644 index 00000000..547e0ea8 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h @@ -0,0 +1,120 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ + +#include <functional> +#include <vector> + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/config_generated.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +// A trie node specifies a node in the tree, either an intermediate node or +// a leaf node. +// A leaf node contains the id as an int of the string match. This id is encoded +// in the lower 31 bits, thus the number of distinct ids is 2^31. +// An intermediate node has an associated label and an offset to its children. +// The label is encoded in the least significant byte and must match the input +// character during matching. + +// A memory mappable trie, compatible with Darts::DoubleArray. +class DoubleArrayTrie { + public: + struct Match { + Match() {} + Match(int id, int match_length) : id(id), match_length(match_length) {} + int id = -1; + int match_length = -1; + bool empty() const { return match_length == -1; } + bool operator==(const Match& m) const { + return m.id == id && m.match_length == match_length; + } + }; + + // nodes and nodes_length specify the array of the nodes of the trie. + explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes) + : nodes_(nodes) {} + + // Finds matches that are prefixes of a string. + template <typename callback> + void IteratePrefixMatches(const utils::string_view& input, + callback update_fn) const; + + // Finds the longest prefix match of a string. 
+ Match LongestPrefixMatch(const utils::string_view& input) const { + Match match; + IteratePrefixMatches(input, [&match](const Match& m) { match = m; }); + return match; + } + + private: + // Returns whether a node as a leaf as a child. + bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; } + + // Returns a value associated with a node. Available when a node is a leaf. + int value(uint32_t i) const { + return static_cast<int>(((*nodes_)[i]) & 0x7fffffff); + } + + // Returns a label associated with a node. + // A leaf node will have the MSB set and thus return an invalid label. + int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; } + + // Returns offset to children. + int32_t offset(uint32_t i) const { + const uint32_t node = (*nodes_)[i]; + return (node >> 10) << ((node & 0x200) >> 6); + } + + const flatbuffers::Vector<uint32_t>* nodes_; +}; + +template <typename callback> +void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input, + callback update_fn) const { + if (nodes_->size() == 0) { + return; + } + uint32_t pos = offset(0); + for (int i = 0; i < input.length(); ++i) { + pos ^= static_cast<unsigned char>(input.at(i)); + if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) { + // No match, exit. + return; + } + const bool node_has_leaf = has_leaf(pos); + pos ^= offset(pos); + if (pos < 0 || pos >= nodes_->size()) { + // We can get here only if the structure is corrupted. 
+ return; + } + if (node_has_leaf) { + update_fn(Match(value(pos), i + 1)); + } + } +} + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc new file mode 100644 index 00000000..72b7262b --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.cc @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h" + +#include <algorithm> +#include <memory> + +#include "include/darts.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data) { + std::vector<int> ids; + ids.reserve(data.size()); + for (int i = 0; i < data.size(); ++i) { + ids.push_back(i); + } + return BuildTrie(data, ids); +} + +std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data, + const std::vector<int>& ids) { + // We make strong assumptions about binary structure of trie. 
+ struct OneElement { + OneElement(const std::string* key_, int index_) + : key(key_), index(index_) {} + const std::string* key; + int index; + bool operator<(const OneElement& el) const { return *key < *el.key; } + }; + std::vector<OneElement> elements; + elements.reserve(data.size()); + auto data_iterator = std::begin(data); + auto ids_iterator = std::begin(ids); + for (; data_iterator != std::end(data) && ids_iterator != std::end(ids); + ++data_iterator, ++ids_iterator) { + elements.emplace_back(&(*data_iterator), *ids_iterator); + } + // Sort by keys. + std::sort(elements.begin(), elements.end()); + + // Create vectors to build the trie. + std::vector<const char*> strings; + std::vector<int32_t> indexes; + strings.reserve(data.size()); + indexes.reserve(data.size()); + for (const auto& el : elements) { + strings.push_back(el.key->c_str()); + indexes.push_back(el.index); + } + auto trie = std::make_unique<Darts::DoubleArray>(); + trie->build(data.size(), const_cast<char**>(&strings[0]), nullptr, + &indexes[0]); + // We make strong assumptions about internal Darts trie structure: + // - it is a vector of 32 bit signed integers + // - the "array" is the only one structure that contains all information about + // the trie. + const uint32_t* trie_data = static_cast<const uint32_t*>(trie->array()); + return std::vector<uint32_t>(trie_data, trie_data + trie->size()); +} + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h new file mode 100644 index 00000000..bc618abb --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ + +#include <string> +#include <vector> + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/config_generated.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data, + const std::vector<int>& ids); + +// A variant where ids are indexes in data. +std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data); + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_test.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_test.cc new file mode 100644 index 00000000..8a53d094 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_test.cc @@ -0,0 +1,78 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +TEST(DoubleArrayTrieTest, Match) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector<std::string> test_strings = {"A", "AAX", "AA", "B"}; + const auto trie_vector = builder.CreateVector(BuildTrie(test_strings)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto pieces = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_pieces(pieces); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + DoubleArrayTrie dat(config->pieces()->nodes()); + EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")), + DoubleArrayTrie::Match(2, 2)); + + std::vector<DoubleArrayTrie::Match> matches; + dat.IteratePrefixMatches( + utils::string_view("AAXL"), + [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); + EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1), + 
DoubleArrayTrie::Match(2, 2), + DoubleArrayTrie::Match(1, 3))); +} + +TEST(DoubleArrayTrieTest, ComplexMatch) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector<std::string> test_strings = {"\xe2\x96\x81the", ",", "s", + "\xe2\x96\x81Hello"}; + const std::vector<int> test_ids = {0, 5, 10, 15}; + const auto trie_vector = + builder.CreateVector(BuildTrie(test_strings, test_ids)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto pieces = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_pieces(pieces); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + DoubleArrayTrie dat(config->pieces()->nodes()); + + std::vector<DoubleArrayTrie::Match> matches; + dat.IteratePrefixMatches( + utils::string_view("\xe2\x96\x81Hello"), + [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); + EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8))); +} + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config.fbs b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config.fbs new file mode 100644 index 00000000..7f1f2bad --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config.fbs @@ -0,0 +1,52 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +include "config.fbs"; + +namespace tflite.ops.custom.sentencepiece; + +table EncoderConfig { + // Version of the encoder. + version: EncoderVersion = SENTENCE_PIECE; + start_code: int32 = 0; + end_code: int32 = 0; + + unknown_code: int32 = -1; + // Weight of "unknown code" when encoding. "Penalty" because it usually has a + // big negative weight,less than any other sentencepiece. + unknown_penalty: float = 0; + + // The offset for encoding, usually used when codes with low codes are reserved + // for some special needs. + encoding_offset: int32; + + // String pieces for encoding. + pieces: Trie; + pieces_scores: [float]; + + // Normalization related parameters. + remove_extra_whitespaces: bool; + + // Add a whitespace prefix before encoding. + add_dummy_prefix: bool; + + // Escape whitespaces during encoding so the decoder can restore them exactly as + // in the input. + escape_whitespaces: bool; + + // Normalization parameters. + normalized_prefixes: Trie; + normalized_replacements: [byte]; +} + +root_type EncoderConfig; diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc new file mode 100644 index 00000000..73e853ff --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.cc @@ -0,0 +1,197 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h" + +#include "absl/status/status.h" +#include "absl/strings/str_replace.h" +#include "src/sentencepiece_model.pb.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +std::tuple<std::vector<uint32_t>, std::vector<int8_t>> +DecodePrecompiledCharsmap( + const ::sentencepiece::NormalizerSpec& normalizer_spec) { + // This function "undoes" encoding done by + // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap. 
+ const char* precompiled_map = normalizer_spec.precompiled_charsmap().data(); + const uint32_t trie_size = + *reinterpret_cast<const uint32_t*>(precompiled_map); + const uint32_t* trie_ptr = + reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t)); + const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>( + precompiled_map + sizeof(uint32_t) + trie_size); + const int normalized_size = normalizer_spec.precompiled_charsmap().length() - + sizeof(uint32_t) - trie_size; + return std::make_tuple( + std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)), + std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size)); +} + +tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer( + const std::string& model_config_str, int encoding_offset) { + ::sentencepiece::ModelProto model_config; + if (!model_config.ParseFromString(model_config_str)) { + return absl::InvalidArgumentError( + "Invalid configuration, can't parse SentencePiece model config " + + model_config.InitializationErrorString()); + } + // Convert sentencepieces. + std::vector<std::string> pieces; + pieces.reserve(model_config.pieces_size()); + std::vector<float> scores; + scores.reserve(model_config.pieces_size()); + std::vector<int> ids; + ids.reserve(model_config.pieces_size()); + float min_score = 0.0; + int index = 0; + for (const auto& piece : model_config.pieces()) { + switch (piece.type()) { + case ::sentencepiece::ModelProto::SentencePiece::NORMAL: + case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED: + pieces.push_back(piece.piece()); + ids.push_back(index); + if (piece.score() < min_score) { + min_score = piece.score(); + } + break; + case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN: + case ::sentencepiece::ModelProto::SentencePiece::CONTROL: + // Ignore unknown and control codes. 
+ break; + default: + return absl::InvalidArgumentError("Invalid SentencePiece piece type " + + piece.piece()); + } + scores.push_back(piece.score()); + ++index; + } + flatbuffers::FlatBufferBuilder builder(1024); + const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids)); + const auto pieces_score_vector = builder.CreateVector(scores); + TrieBuilder pieces_trie_builder(builder); + pieces_trie_builder.add_nodes(pieces_trie_vector); + const auto pieces_trie_fbs = pieces_trie_builder.Finish(); + + // Converting normalization. + const auto [normalization_trie, normalization_strings] = + DecodePrecompiledCharsmap(model_config.normalizer_spec()); + const auto normalization_trie_vector = + builder.CreateVector(normalization_trie); + TrieBuilder normalization_trie_builder(builder); + normalization_trie_builder.add_nodes(normalization_trie_vector); + const auto normalization_trie_fbs = normalization_trie_builder.Finish(); + const auto normalization_strings_fbs = + builder.CreateVector(normalization_strings); + + EncoderConfigBuilder ecb(builder); + ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE); + ecb.add_start_code(model_config.trainer_spec().bos_id()); + ecb.add_end_code(model_config.trainer_spec().eos_id()); + ecb.add_unknown_code(model_config.trainer_spec().unk_id()); + ecb.add_unknown_penalty(min_score - kUnkPenalty); + ecb.add_encoding_offset(encoding_offset); + ecb.add_pieces(pieces_trie_fbs); + ecb.add_pieces_scores(pieces_score_vector); + ecb.add_remove_extra_whitespaces( + model_config.normalizer_spec().remove_extra_whitespaces()); + ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix()); + ecb.add_escape_whitespaces( + model_config.normalizer_spec().escape_whitespaces()); + ecb.add_normalized_prefixes(normalization_trie_fbs); + ecb.add_normalized_replacements(normalization_strings_fbs); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + return std::string(reinterpret_cast<const 
char*>(builder.GetBufferPointer()), + builder.GetSize()); +} + +tflite::support::StatusOr<std::string> +ConvertSentencepieceModelToFlatBufferForDecoder( + const std::string& model_config_str, int encoding_offset) { + ::sentencepiece::ModelProto model_config; + if (!model_config.ParseFromString(model_config_str)) { + return absl::InvalidArgumentError( + "Invalid configuration, can't parse SentencePiece model config " + + model_config.InitializationErrorString()); + } + flatbuffers::FlatBufferBuilder builder(1024); + // Collect sentencepieces. + std::vector<std::string> pieces; + for (const auto& piece : model_config.pieces()) { + // In the original library all pieces processing is done during decoding. + // Because it is independent from context or parameters we can do it in + // advance here. + switch (piece.type()) { + case ::sentencepiece::ModelProto::SentencePiece::NORMAL: + case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED: + pieces.push_back( + absl::StrReplaceAll(piece.piece(), {{kSpaceSymbol, " "}})); + break; + case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN: + pieces.push_back( + kDefaultUnknownSymbol); // Always decode with the default unknown. 
+ break; + default: + pieces.push_back(""); + } + } + const auto pieces_fbs = builder.CreateVectorOfStrings(pieces); + DecoderConfigBuilder decb(builder); + + decb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE); + decb.add_encoding_offset(encoding_offset); + decb.add_decode_pieces(pieces_fbs); + decb.add_remove_dummy_prefix( + model_config.normalizer_spec().add_dummy_prefix()); + + FinishDecoderConfigBuffer(builder, decb.Finish()); + return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), + builder.GetSize()); +} + +int GetVocabularySize(const std::string& model_string) { + const EncoderConfig* config = GetEncoderConfig(model_string.data()); + return config->pieces_scores()->size() + config->encoding_offset(); +} + +std::string ConvertSentencepieceModel(const std::string& model_string) { + const auto result = ConvertSentencepieceModelToFlatBuffer(model_string); + // TODO(mgubin): Propogate error to the Python code and throw correct + // exception. + assert(result.status().ok()); + return result.value(); +} + +std::string ConvertSentencepieceModelForDecoder( + const std::string& model_string) { + const auto result = + ConvertSentencepieceModelToFlatBufferForDecoder(model_string); + // TODO(mgubin): Propogate error to the Python code and throw correct + // exception. + assert(result.status().ok()); + return result.value(); +} + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h new file mode 100644 index 00000000..5687b628 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_MODEL_CONVERTER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_MODEL_CONVERTER_H_ +#include <string> + +#include "tensorflow_lite_support/cc/port/statusor.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +// Converts Sentencepiece configuration to flatbuffer format. +// encoding_offset is used by some encoders that combine different encodings. +tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer( + const std::string& model_config_str, int encoding_offset = 0); + +// Converts Sentencepiece configuration to flatbuffer format for encoder. +// encoding_offset is used by some encoders that combine different encodings. +tflite::support::StatusOr<std::string> +ConvertSentencepieceModelToFlatBufferForDecoder( + const std::string& model_config_str, int encoding_offset = 0); + +// The functions that are provided for the Python wrapper. +std::string ConvertSentencepieceModel(const std::string& model_string); +std::string ConvertSentencepieceModelForDecoder( + const std::string& model_string); + +// Returns size of a vocabulary from Sentencepiece configuration in flatbuffer +// format. 
+int GetVocabularySize(const std::string& model_string); + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_MODEL_CONVERTER_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl new file mode 100644 index 00000000..87695a46 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl @@ -0,0 +1,86 @@ +"""Build definitions supporting platform-independent native build.""" + +load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_copts", "tf_opts_nortti_if_android") +load("@bazel_skylib//lib:selects.bzl", "selects") + +def micore_if(android, ios = [], default = []): + """Helper to create a select. + + Args: + android: what to return if compiling for Android. + ios: what to return if compiling for iOS. + default: what to return otherwise. + Returns: + the `android` list for Android compilation and the + `default` list otherwise. + """ + return select({ + ":android": android, + ":apple": ios, + "//conditions:default": default, + }) + +def micore_tf_copts(): + """C options for Tensorflow builds. + + Returns: + a list of copts which must be used by each cc_library which + refers to Tensorflow. Enables the library to compile both for + Android and for Linux. + """ + return tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [ + "-Wno-narrowing", + "-Wno-sign-compare", + "-Wno-overloaded-virtual", + ] + micore_if( + android = [ + # Set a define so Tensorflow's register_types.h + # adopts to support a rich set of types, to be pruned by + # selective registration. + "-DSUPPORT_SELECTIVE_REGISTRATION", + # Selective registration uses constexprs with recursive + # string comparisons; that can lead to compiler errors, so + # we increase the constexpr recursion depth. 
+ "-fconstexpr-depth=1024", + ], + ) + selects.with_or({ + # If building for armeabi-v7a, and if compilation_mode is 'fastbuild' + # or 'dbg' then forcefully add -Oz to the list compiler options. + # Without it, some TF dependencies can't build (b/112286436). If + # compilation_mode is 'opt' then rely on the toolchain default. + ( + ":armeabi_v7a_and_fastbuild", + ":armeabi_v7a_and_dbg", + ): ["-Oz"], + "//conditions:default": [], + }) + +def micore_tf_deps(): + """Dependencies for Tensorflow builds. + + Returns: + list of dependencies which must be used by each cc_library + which refers to Tensorflow. Enables the library to compile both for + Android and for Linux. Use this macro instead of directly + declaring dependencies on Tensorflow. + """ + return micore_if( + android = [ + # Link to library which does not contain any ops. + "@org_tensorflow//tensorflow/core:portable_tensorflow_lib_lite", + "@gemmlowp//:eight_bit_int_gemm", + "@fft2d//:fft2d", + ], + ios = [ + "@org_tensorflow//tensorflow/core:portable_tensorflow_lib", + "@gemmlowp//:eight_bit_int_gemm", + "@fft2d//:fft2d", + ], + default = [ + # Standard references for Tensorflow when building for Linux. We use + # an indirection via the alias targets below, to facilitate whitelisting + # these deps in the mobile license presubmit checks. + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ], + ) diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.cc new file mode 100644 index 00000000..86e186da --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.cc @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h" + +#include <string> +#include <tuple> + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +DecoderResult DecodeString(const std::vector<int>& encoded, + const void* config_buffer) { + DecoderResult result; + + // Get the config from the buffer. + const DecoderConfig* config = GetDecoderConfig(config_buffer); + if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) { + result.type = DecoderResultType::WRONG_CONFIG; + return result; + } + bool remove_dummy_prefix = config->remove_dummy_prefix(); + const auto config_pieces = config->decode_pieces(); + for (const auto code : encoded) { + const int real_code = code - config->encoding_offset(); + if (real_code >= config_pieces->size()) { + result.type = DecoderResultType::INVALID_INPUT; + return result; + } + const auto& piece_text = config_pieces->GetAsString(real_code); + const char* piece_str = piece_text->c_str(); + if (remove_dummy_prefix && *piece_str == ' ') { + ++piece_str; + } + result.decoded.append(piece_str); + remove_dummy_prefix = false; + } + // TODO(mgubin): Denormalize the string, haven't seen any Sentencepiece model + // with a denormalizer. 
+ return result; +} + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h new file mode 100644 index 00000000..a4424687 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_DECODER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_DECODER_H_ + +// Sentencepiece decoder optimized with memmapped model. + +#include <string> +#include <vector> + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +enum class DecoderResultType { + SUCCESS = 0, + WRONG_CONFIG = 1, + INVALID_INPUT = 2 +}; + +struct DecoderResult { + DecoderResultType type = DecoderResultType::SUCCESS; + std::string decoded; +}; + +// Decodes one string from a vector of id. Takes the configuration as a +// type-erased buffer. 
+DecoderResult DecodeString(const std::vector<int>& encoded, + const void* config_buffer); + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_DECODER_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc new file mode 100644 index 00000000..04d1c85a --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h" + +#include <fstream> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "src/sentencepiece.pb.h" +#include "src/sentencepiece_processor.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +namespace internal { + +tensorflow::Status TFReadFileToString(const std::string& filepath, + std::string* data) { + return tensorflow::ReadFileToString(tensorflow::Env::Default(), + /*test_path*/ filepath, data); +} + +absl::Status StdReadFileToString(const std::string& filepath, + std::string* data) { + std::ifstream infile(filepath); + if (!infile.is_open()) { + return absl::NotFoundError( + absl::StrFormat("Error when opening %s", filepath)); + } + std::string contents((std::istreambuf_iterator<char>(infile)), + (std::istreambuf_iterator<char>())); + data->append(contents); + infile.close(); + return absl::OkStatus(); +} + +} // namespace internal + +namespace { +static char kConfigFilePath[] = + "tensorflow_lite_support/custom_ops/kernel/" + "sentencepiece/testdata/sentencepiece.model"; + +TEST(OptimizedEncoder, ConfigConverter) { + std::string config; + auto status = internal::StdReadFileToString(kConfigFilePath, &config); + + ASSERT_TRUE(status.ok()); + + ::sentencepiece::SentencePieceProcessor processor; + ASSERT_OK(processor.LoadFromSerializedProto(config)); + const auto converted_model = ConvertSentencepieceModelForDecoder(config); + const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95"); + ::sentencepiece::SentencePieceText reference_encoded; + CHECK_OK(processor.Encode(test_string, &reference_encoded)); + + std::vector<int> encoded_vector; + encoded_vector.reserve(reference_encoded.pieces_size()); + for (const 
auto& piece : reference_encoded.pieces()) { + encoded_vector.push_back(piece.id()); + } + std::string ref_decoded; + ASSERT_OK(processor.Decode(encoded_vector, &ref_decoded)); + const auto decoded = DecodeString(encoded_vector, converted_model.data()); + ASSERT_EQ(decoded.type, DecoderResultType::SUCCESS); + ASSERT_EQ(ref_decoded, decoded.decoded); +} +} // namespace + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc new file mode 100644 index 00000000..5a59ee48 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.cc @@ -0,0 +1,239 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h" + +#include <algorithm> +#include <tuple> + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { +namespace { + +const char kSpaceSymbol[] = "\xe2\x96\x81"; + +template <typename processing_callback> +std::tuple<std::string, std::vector<int>> process_string( + const std::string& input, const std::vector<int>& offsets, + const processing_callback& pc) { + std::string result_string; + result_string.reserve(input.size()); + std::vector<int> result_offsets; + result_offsets.reserve(offsets.size()); + for (int i = 0, j = 0; i < input.size();) { + auto [consumed, new_string] = pc(input.data() + i, input.size() - i); + if (consumed == 0) { + // Skip the current byte and move forward. + result_string.push_back(input[i]); + result_offsets.push_back(offsets[j]); + i++; + j++; + continue; + } + result_string.append(new_string.data(), new_string.length()); + for (int i = 0; i < new_string.length(); ++i) { + result_offsets.push_back(offsets[j]); + } + j += consumed; + i += consumed; + } + return std::make_tuple(result_string, result_offsets); +} + +inline char is_whitespace(char c) { + return c == ' ' || c == '\t' || c == '\r' || c == '\n'; +} + +std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data, + int len) { + if (len == 0 || !is_whitespace(*data)) { + return std::make_tuple(0, utils::string_view(nullptr, 0)); + } + int num_consumed = 1; + for (; num_consumed < len && is_whitespace(data[num_consumed]); + ++num_consumed) { + } + return num_consumed > 1 + ? 
std::make_tuple(num_consumed, utils::string_view(" ", 1)) + : std::make_tuple(0, utils::string_view(nullptr, 0)); +} + +std::tuple<int, utils::string_view> find_replacement( + const char* data, int len, const DoubleArrayTrie& dat, + const flatbuffers::Vector<int8_t>& replacements) { + const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len)); + if (!max_match.empty()) { + // Because flatbuffer byte is signed char which is not the same as char, + // there is the reinterpret_cast here. + const char* replaced_string_ptr = + reinterpret_cast<const char*>(replacements.data() + max_match.id); + return std::make_tuple(max_match.match_length, + utils::string_view(replaced_string_ptr)); + } + return std::make_tuple(0, utils::string_view(nullptr, 0)); +} +} // namespace + +std::tuple<std::string, std::vector<int>> NormalizeString( + const std::string& in_string, const EncoderConfig& config) { + std::vector<int> output_offsets; + std::string result = in_string; + output_offsets.reserve(in_string.length()); + for (int i = 0; i < in_string.length(); ++i) { + output_offsets.push_back(i); + } + if (in_string.empty()) { + return std::make_tuple(result, output_offsets); + } + if (config.add_dummy_prefix()) { + result.insert(result.begin(), ' '); + output_offsets.insert(output_offsets.begin(), 0); + } + // Greedely replace normalized_prefixes with normalized_replacements + if (config.normalized_prefixes() != nullptr && + config.normalized_replacements() != nullptr) { + const DoubleArrayTrie normalized_prefixes_matcher( + config.normalized_prefixes()->nodes()); + const auto norm_replace = [&config, &normalized_prefixes_matcher]( + const char* data, int len) { + return find_replacement(data, len, normalized_prefixes_matcher, + *config.normalized_replacements()); + }; + std::tie(result, output_offsets) = + process_string(result, output_offsets, norm_replace); + } + if (config.remove_extra_whitespaces()) { + std::tie(result, output_offsets) = + process_string(result, 
output_offsets, remove_extra_whitespaces); + if (!result.empty() && is_whitespace(result.back())) { + result.pop_back(); + output_offsets.pop_back(); + } + } + if (config.escape_whitespaces()) { + const auto replace_whitespaces = [](const char* data, int len) { + if (len > 0 && is_whitespace(*data)) { + return std::make_tuple(1, utils::string_view(kSpaceSymbol)); + } + return std::make_tuple(0, utils::string_view(nullptr, 0)); + }; + std::tie(result, output_offsets) = + process_string(result, output_offsets, replace_whitespaces); + } + + return std::make_tuple(result, output_offsets); +} + +EncoderResult EncodeNormalizedString(const std::string& str, + const std::vector<int>& offsets, + const EncoderConfig& config, bool add_bos, + bool add_eos, bool reverse) { + const DoubleArrayTrie piece_matcher(config.pieces()->nodes()); + const flatbuffers::Vector<float>* piece_scores = config.pieces_scores(); + const int unknown_code = config.unknown_code(); + const float unknown_penalty = config.unknown_penalty(); + struct LatticeElement { + float score = 0; + int code = -1; + int prev_position = -1; + LatticeElement(float score_, int code_, int prev_position_) + : score(score_), code(code_), prev_position(prev_position_) {} + LatticeElement() {} + }; + const int length = str.length(); + std::vector<LatticeElement> lattice(length + 1); + for (int i = 0; i < length; ++i) { + if (i > 0 && lattice[i].prev_position < 0) { + // This state is unreachable. + continue; + } + if (unknown_code >= 0) { + // Put unknown code. + const float penalized_score = lattice[i].score + unknown_penalty; + const int pos = i + 1; + LatticeElement& current_element = lattice[pos]; + if (current_element.prev_position < 0 || + current_element.score < penalized_score) { + current_element = LatticeElement( + penalized_score, unknown_code, + // If the current state is already reached by unknown code, merge + // states. + lattice[i].code == unknown_code ? 
lattice[i].prev_position : i); + } + } + auto lattice_update = [&lattice, i, + piece_scores](const DoubleArrayTrie::Match& m) { + LatticeElement& target_element = lattice[i + m.match_length]; + const float score = lattice[i].score + (*piece_scores)[m.id]; + if (target_element.prev_position < 0 || target_element.score < score) { + target_element = LatticeElement(score, m.id, i); + } + }; + piece_matcher.IteratePrefixMatches( + utils::string_view(str.data() + i, length - i), lattice_update); + } + + EncoderResult result; + if (add_eos) { + result.codes.push_back(config.end_code()); + result.offsets.push_back(length); + } + if (lattice[length].prev_position >= 0) { + for (int pos = length; pos > 0;) { + auto code = lattice[pos].code; + if (code != config.unknown_code()) { + code += config.encoding_offset(); + } + result.codes.push_back(code); + pos = lattice[pos].prev_position; + result.offsets.push_back(offsets[pos]); + } + } + if (add_bos) { + result.codes.push_back(config.start_code()); + result.offsets.push_back(0); + } + if (!reverse) { + std::reverse(result.codes.begin(), result.codes.end()); + std::reverse(result.offsets.begin(), result.offsets.end()); + } + return result; +} + +EncoderResult EncodeString(const std::string& string, const void* config_buffer, + bool add_bos, bool add_eos, bool reverse) { + // Get the config from the buffer. 
+ const EncoderConfig* config = GetEncoderConfig(config_buffer); + if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) { + EncoderResult result; + result.type = EncoderResultType::WRONG_CONFIG; + return result; + } + std::string normalized_string; + std::vector<int> offsets; + std::tie(normalized_string, offsets) = NormalizeString(string, *config); + return EncodeNormalizedString(normalized_string, offsets, *config, add_bos, + add_eos, reverse); +} + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h new file mode 100644 index 00000000..44d6e88f --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ + +// Sentencepiece encoder optimized with memmapped model. 
+ +#include <string> +#include <tuple> +#include <vector> + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 }; + +struct EncoderResult { + EncoderResultType type = EncoderResultType::SUCCESS; + std::vector<int> codes; + std::vector<int> offsets; +}; +std::tuple<std::string, std::vector<int>> NormalizeString( + const std::string& in_string, const EncoderConfig& config); + +// Encodes one string and returns ids and offsets. Takes the configuration as a +// type-erased buffer. +EncoderResult EncodeString(const std::string& string, const void* config_buffer, + bool add_bos, bool add_eos, bool reverse); + +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc new file mode 100644 index 00000000..ad3cd27f --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h" + +#include <fstream> + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h" +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "src/sentencepiece.pb.h" +#include "src/sentencepiece_processor.h" +#include "tensorflow/core/platform/env.h" + + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { + +namespace internal { + +tensorflow::Status TFReadFileToString( + const std::string& filepath, std::string* data) { + return tensorflow::ReadFileToString( + tensorflow::Env::Default(), /*test_path*/ filepath, data); +} + +absl::Status StdReadFileToString( + const std::string& filepath, std::string* data) { + std::ifstream infile(filepath); + if (!infile.is_open()) { + return absl::NotFoundError( + absl::StrFormat("Error when opening %s", filepath)); + } + std::string contents((std::istreambuf_iterator<char>(infile)), + (std::istreambuf_iterator<char>())); + data->append(contents); + infile.close(); + return absl::OkStatus(); +} +} // namespace internal + +namespace { + +static char kConfigFilePath[] = + "tensorflow_lite_support/custom_ops/kernel/" + "sentencepiece/testdata/sentencepiece.model"; + +TEST(OptimizedEncoder, NormalizeStringWhitestpaces) { + flatbuffers::FlatBufferBuilder builder(1024); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(true); + ecb.add_add_dummy_prefix(true); + ecb.add_escape_whitespaces(true); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { 
+ const auto [res_string, offsets] = NormalizeString("x y", *config); + EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); + EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3)); + } + { + const auto [res_string, offsets] = NormalizeString("\tx y\n", *config); + EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); + EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4)); + } +} + +TEST(OptimizedEncoder, NormalizeStringReplacement) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"}; + const char norm_replacements[] = "A1\0A2\0A3\0A4"; + const auto trie_vector = + builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9})); + const auto norm_r = builder.CreateVector<int8_t>( + reinterpret_cast<const signed char*>(norm_replacements), + sizeof(norm_replacements)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto norm_p = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(false); + ecb.add_normalized_prefixes(norm_p); + ecb.add_normalized_replacements(norm_r); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto [res_string, offsets] = + NormalizeString("ABAABAAABAAAA", *config); + EXPECT_EQ(res_string, "A1BA2BA3BA4"); + EXPECT_THAT(offsets, + ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9)); + } +} + +TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA", + "X"}; + const char norm_replacements[] = "A1\0A2\0A3\0A4\0 "; + const auto trie_vector = + builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12})); + const auto norm_r = builder.CreateVector<int8_t>( + reinterpret_cast<const signed char*>(norm_replacements), + sizeof(norm_replacements)); + 
TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto norm_p = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(true); + ecb.add_normalized_prefixes(norm_p); + ecb.add_normalized_replacements(norm_r); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto [res_string, offsets] = + NormalizeString("XXABAABAAABAAAA", *config); + EXPECT_EQ(res_string, " A1BA2BA3BA4"); + EXPECT_THAT(offsets, + ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11)); + } +} + +TEST(OptimizedEncoder, ConfigConverter) { + std::string config; + auto status = internal::StdReadFileToString(kConfigFilePath, &config); + ASSERT_TRUE(status.ok()); + + ::sentencepiece::SentencePieceProcessor processor; + ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok()); + const auto converted_model = ConvertSentencepieceModel(config); + const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95"); + const auto encoded = + EncodeString(test_string, converted_model.data(), false, false, false); + ASSERT_EQ(encoded.codes.size(), encoded.offsets.size()); + + ::sentencepiece::SentencePieceText reference_encoded; + ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok()); + EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size()); + for (int i = 0; i < encoded.codes.size(); ++i) { + EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id()); + EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin()); + } +} + +} // namespace +} // namespace sentencepiece +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.cc new file mode 100644 index 00000000..5345409f --- /dev/null +++ 
b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.cc @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h" + +namespace tflite { +namespace ops { +namespace custom { +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER(); +TfLiteRegistration* Register_SENTENCEPIECE_DETOKENIZER(); +} // namespace custom +} // namespace ops +} // namespace tflite + +extern "C" void TFLite_SentencepieceTokenizerRegisterer( + tflite::MutableOpResolver* resolver) { + resolver->AddCustom("TFSentencepieceTokenizeOp", + tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER()); + resolver->AddCustom( + "TFSentencepieceDetokenizeOp", + tflite::ops::custom::Register_SENTENCEPIECE_DETOKENIZER()); +} diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h new file mode 100644 index 00000000..deb4e4ee --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/py_tflite_registerer.h @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ +#include "tensorflow/lite/mutable_op_resolver.h" + +// C-function that is called from the Python Wrapper. + +extern "C" void TFLite_SentencepieceTokenizerRegisterer( + tflite::MutableOpResolver *resolver); + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_PY_TFLITE_REGISTERER_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h new file mode 100644 index 00000000..55644ba6 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_

namespace tflite {
namespace ops {
namespace custom {
namespace sentencepiece {

// Score penalty applied to the <unk> piece. Copied from
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
constexpr float kUnkPenalty = 10.0;

// The two constants below are copied from
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc

// UTF-8 encoding of U+2581 (LOWER ONE EIGHTH BLOCK), the symbol that
// sentencepiece substitutes for whitespace.
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";

// UTF-8 encoding of U+2047 (DOUBLE QUESTION MARK), padded with spaces;
// emitted in place of <unk> so that unknown pieces are easy to spot in
// decoded output.
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";

}  // namespace sentencepiece
}  // namespace custom
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_
#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_

// Input-tensor indices shared between the TF and TFLite
// SentencepieceDetokenizer kernels.
//
// NOTE(review): sentencepiece_tokenizer.h declares constants with the same
// names in this same namespace; including both headers in one translation
// unit would be a redefinition error — confirm no TU ever needs both.
namespace tensorflow {
namespace ops {
constexpr int kSPModelIndex = 0;  // serialized sentencepiece model
constexpr int kInputIndex = 1;    // flat int32 token ids
constexpr int kInputSplits = 2;   // row-split offsets into the ids
// Index 3 is skipped here — presumably reserved to mirror the tokenizer's
// input layout; TODO confirm against the op schemas.
constexpr int kAddBOSInput = 4;
constexpr int kAddEOSInput = 5;
constexpr int kReverseInput = 6;
}  // namespace ops
}  // namespace tensorflow

#endif  // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_DETOKENIZER_H_
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h" + +namespace tensorflow { +namespace ops { +REGISTER_OP("TFSentencepieceDetokenizeOp") + .Input("sp_model: uint8") + .Input("input_values: int32") + .Input("input_splits: Tsplits") + .Attr("Tsplits: {int32, int64} = DT_INT64") + .Output("output: string") + .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + + shape_inference::DimensionHandle dim; + TF_RETURN_IF_ERROR(c->Subtract(c->NumElements(c->input(2)), 1, &dim)); + c->set_output(0, c->Vector(dim)); + return Status::OK(); + }); + +template <typename Tsplits> +class TFSentencepieceDetokenizerOp : public tensorflow::OpKernel { + public: + explicit TFSentencepieceDetokenizerOp(tensorflow::OpKernelConstruction* ctx) + : OpKernel(ctx) {} + void Compute(tensorflow::OpKernelContext* ctx) override { + const auto& model_tensor = ctx->input(kSPModelIndex); + const auto& input_values_tensor = ctx->input(kInputIndex); + const auto input_values_flat = + input_values_tensor.flat<tensorflow::int32>(); + const auto& input_splits_tensor = ctx->input(kInputSplits); + const auto input_splits_flat = input_splits_tensor.flat<Tsplits>(); + const int num_of_sentences = input_splits_flat.size() - 1; + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(ctx, 
+ ctx->allocate_output(0, {num_of_sentences}, &output_tensor)); + auto output_flat = output_tensor->flat<tensorflow::tstring>(); + std::vector<int> codes_for_split; + int input_offset = 0; + for (int i = 0; i < num_of_sentences; i++) { + // Create a vector of int32 from input according to spans. + const int split_size = input_splits_flat(i + 1) - input_splits_flat(i); + codes_for_split.clear(); + codes_for_split.reserve(split_size); + for (int j = 0; j < split_size; ++j) { + codes_for_split.push_back(input_values_flat(input_offset++)); + } + const auto res = tflite::ops::custom::sentencepiece::DecodeString( + codes_for_split, model_tensor.data()); + OP_REQUIRES( + ctx, + res.type == + tflite::ops::custom::sentencepiece::DecoderResultType::SUCCESS, + tensorflow::Status(tensorflow::error::INTERNAL, + "Sentencepiece conversion failed")); + output_flat(i) = res.decoded; + } + } +}; +} // namespace ops +} // namespace tensorflow + +REGISTER_KERNEL_BUILDER( + Name("TFSentencepieceDetokenizeOp") + .Device(tensorflow::DEVICE_CPU) + .TypeConstraint<tensorflow::int32>("Tsplits"), + tensorflow::ops::TFSentencepieceDetokenizerOp<tensorflow::int32>); +REGISTER_KERNEL_BUILDER( + Name("TFSentencepieceDetokenizeOp") + .Device(tensorflow::DEVICE_CPU) + .TypeConstraint<tensorflow::int64>("Tsplits"), + tensorflow::ops::TFSentencepieceDetokenizerOp<tensorflow::int64>); diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc new file mode 100644 index 00000000..54b34e4e --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer_tflite.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +/** + * Sentencepiece tflite detokenizer implementation. + */ +#include <algorithm> +#include <iterator> + +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_decoder.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_detokenizer.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { +namespace detokenizer { + +constexpr int kOutputValuesInd = 0; +// Initializes text encoder object from serialized parameters. +void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, + size_t /*length*/) { + return nullptr; +} +void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // TODO(mgubin): Add checks for input and output tensors. + TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + SetTensorToDynamic(&output_values); + // TODO(mgubin): Check input types. 
+ + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor& model_tensor = + context->tensors[node->inputs->data[tensorflow::ops::kSPModelIndex]]; + const auto model_buffer_data = model_tensor.data.data; + const TfLiteTensor& input_encoded = + context->tensors[node->inputs->data[tensorflow::ops::kInputIndex]]; + const int32_t* input_encoded_data = input_encoded.data.i32; + const TfLiteTensor& input_splits = + context->tensors[node->inputs->data[tensorflow::ops::kInputSplits]]; + const int num_of_sentences = NumElements(input_splits.dims) - 1; + const int32_t* input_splits_data = input_splits.data.i32; + + DynamicBuffer buf; + + std::vector<int> codes_for_split; + int input_offset = 0; + for (int i = 0; i < num_of_sentences; i++) { + // Create a vector of int32 from input according to spans. + const int split_size = input_splits_data[i + 1] - input_splits_data[i]; + codes_for_split.clear(); + std::copy(input_encoded_data + input_offset, + input_encoded_data + input_offset + split_size, + std::back_inserter(codes_for_split)); + const auto res = DecodeString(codes_for_split, model_buffer_data); + TF_LITE_ENSURE_MSG(context, res.type == DecoderResultType::SUCCESS, + "Sentencepiece decoding failed"); + buf.AddString(res.decoded.data(), res.decoded.length()); + input_offset += split_size; + } + TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + buf.WriteToTensor(&output_values, nullptr); + return kTfLiteOk; +} +} // namespace detokenizer +} // namespace sentencepiece + +TfLiteRegistration* Register_SENTENCEPIECE_DETOKENIZER() { + static TfLiteRegistration r = { + sentencepiece::detokenizer::Initialize, sentencepiece::detokenizer::Free, + sentencepiece::detokenizer::Prepare, sentencepiece::detokenizer::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git 
a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h new file mode 100644 index 00000000..cb3ee07f --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_ + +// Constants are shared between TF and TFLite SentencepieceTokenizer kernels. 
#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_
#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_

// Input-tensor indices shared between the TF and TFLite
// SentencepieceTokenizer kernels.
namespace tensorflow {
namespace ops {

constexpr int kSPModelIndex = 0;  // serialized sentencepiece model
constexpr int kInputIndex = 1;    // strings to tokenize
// Indices 2 and 3 (nbest_size and alpha in the op schema) have no constants
// here because these kernels do not read them.
constexpr int kAddBOSInput = 4;
constexpr int kAddEOSInput = 5;
constexpr int kReverseInput = 6;
}  // namespace ops
}  // namespace tensorflow

// Fix: the guard comment previously read "sTENSORFLOW_..." (stray leading
// 's'), which did not match the opening #ifndef.
#endif  // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_H_
+==============================================================================*/ + +#include <iterator> +#include <vector> + +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" + +namespace tensorflow { +namespace ops{ + +// copied from third_party/tensorflow_text/core/ops/sentencepiece_ops.cc +REGISTER_OP("TFSentencepieceTokenizeOp") + .Input("sp_model: uint8") + .Input("input: string") + .Input("nbest_size: int32") + .Input("alpha: float") + .Input("add_bos: bool") + .Input("add_eos: bool") + .Input("reverse: bool") + .Attr("out_type: {int32, string} = DT_INT32") + .Attr("Tsplits: {int32, int64} = DT_INT32") + .Output("output_values: out_type") + .Output("output_splits: Tsplits") + .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { + tensorflow::shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + + c->set_output( + 0, c->Vector( + tensorflow::shape_inference::InferenceContext::kUnknownDim)); + + tensorflow::shape_inference::DimensionHandle num_splits; + TF_RETURN_IF_ERROR(c->Add(c->NumElements(c->input(1)), 1, &num_splits)); + c->set_output(1, c->Vector(num_splits)); + return tensorflow::Status::OK(); + }); + +class TFSentencepieceOp : public tensorflow::OpKernel { + public: + explicit 
TFSentencepieceOp(tensorflow::OpKernelConstruction* ctx) + : OpKernel(ctx) {} + void Compute(tensorflow::OpKernelContext* ctx) override { + const auto& model_tensor = ctx->input(kSPModelIndex); + const auto& input_values_tensor = ctx->input(kInputIndex); + const auto input_values_flat = + input_values_tensor.flat<tensorflow::tstring>(); + const int num_of_input_values = input_values_flat.size(); + + const auto& add_bos_tensor = ctx->input(kAddBOSInput); + const bool add_bos = add_bos_tensor.scalar<bool>()(); + const auto& add_eos_tensor = ctx->input(kAddEOSInput); + const bool add_eos = add_eos_tensor.scalar<bool>()(); + const auto& reverse_tensor = ctx->input(kReverseInput); + const bool reverse = reverse_tensor.scalar<bool>()(); + + std::vector<int32> encoded; + std::vector<int32> splits; + for (int i = 0; i < num_of_input_values; ++i) { + const auto res = ::tflite::ops::custom::sentencepiece::EncodeString( + input_values_flat(i), model_tensor.data(), add_bos, add_eos, reverse); + OP_REQUIRES( + ctx, + res.type == + ::tflite::ops::custom::sentencepiece::EncoderResultType::SUCCESS, + tensorflow::Status(tensorflow::error::INTERNAL, + "Sentencepiece conversion failed")); + std::copy(res.codes.begin(), res.codes.end(), + std::back_inserter(encoded)); + splits.emplace_back(encoded.size()); + } + tensorflow::Tensor* output_values_tensor = nullptr; + tensorflow::Tensor* output_splits_tensor = nullptr; + + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, {encoded.size()}, &output_values_tensor)); + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {splits.size() + 1}, + &output_splits_tensor)); + + auto values_tensor_flat = output_values_tensor->vec<int32>(); + auto splits_tensor_flat = output_splits_tensor->vec<int32>(); + for (int i = 0; i < encoded.size(); ++i) { + values_tensor_flat(i) = encoded[i]; + } + splits_tensor_flat(0) = 0; + for (int i = 0; i < splits.size(); ++i) { + splits_tensor_flat(i + 1) = splits[i]; + } + } +}; + +} // namespace ops +} // namespace tensorflow 
+REGISTER_KERNEL_BUILDER( + Name("TFSentencepieceTokenizeOp").Device(tensorflow::DEVICE_CPU), + tensorflow::ops::TFSentencepieceOp); diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc new file mode 100644 index 00000000..8309a6a2 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer_tflite.cc @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * Sentencepiece tflite tokenizer implementation. 
+ */ +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/optimized_encoder.h" +#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_tokenizer.h" +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace sentencepiece { +namespace tokenizer { + +constexpr int kOutputValuesInd = 0; +constexpr int kOutputSplitsInd = 1; + +namespace { +TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) { + TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size()); + int index = 0; + for (const int size : sizes) { + array_size->data[index++] = size; + } + return array_size; +} +} // namespace + +// Initializes text encoder object from serialized parameters. +void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, + size_t /*length*/) { + return nullptr; +} +void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // TODO(mgubin): Add checks for input and output tensors. 
+ TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + SetTensorToDynamic(&output_values); + + TfLiteTensor& output_splits = + context->tensors[node->outputs->data[kOutputSplitsInd]]; + SetTensorToDynamic(&output_splits); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor& model_tensor = + context->tensors[node->inputs->data[tensorflow::ops::kSPModelIndex]]; + const auto model_buffer_data = model_tensor.data.data; + const TfLiteTensor& input_text = + context->tensors[node->inputs->data[tensorflow::ops::kInputIndex]]; + + const TfLiteTensor add_bos_tensor = + context->tensors[node->inputs->data[tensorflow::ops::kAddBOSInput]]; + const bool add_bos = add_bos_tensor.data.b[0]; + const TfLiteTensor add_eos_tensor = + context->tensors[node->inputs->data[tensorflow::ops::kAddEOSInput]]; + const bool add_eos = add_eos_tensor.data.b[0]; + const TfLiteTensor reverse_tensor = + context->tensors[node->inputs->data[tensorflow::ops::kReverseInput]]; + const bool reverse = reverse_tensor.data.b[0]; + + std::vector<int32> encoded; + std::vector<int32> splits; + const int num_strings = tflite::GetStringCount(&input_text); + for (int i = 0; i < num_strings; ++i) { + const auto strref = tflite::GetString(&input_text, i); + const auto res = EncodeString(std::string(strref.str, strref.len), + model_buffer_data, add_bos, add_eos, reverse); + TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS, + "Sentencepiece conversion failed"); + std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded)); + splits.emplace_back(encoded.size()); + } + + TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor( + context, &output_values, + CreateSizeArray({static_cast<int>(encoded.size())}))); + int32_t* output_values_flat = output_values.data.i32; + std::copy(encoded.begin(), encoded.end(), 
output_values_flat); + TfLiteTensor& output_splits = + context->tensors[node->outputs->data[kOutputSplitsInd]]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output_splits, + CreateSizeArray({static_cast<int>(splits.size() + 1)}))); + int32_t* output_splits_flat = output_splits.data.i32; + *output_splits_flat = 0; + std::copy(splits.begin(), splits.end(), output_splits_flat + 1); + return kTfLiteOk; +} +} // namespace tokenizer +} // namespace sentencepiece + +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() { + static TfLiteRegistration r = { + sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free, + sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/testdata/sentencepiece.model b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/testdata/sentencepiece.model Binary files differnew file mode 100644 index 00000000..041188ff --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/testdata/sentencepiece.model diff --git a/tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h new file mode 100644 index 00000000..13bc021e --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_UTILS_H_
#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_UTILS_H_

// Fix: strlen() and memcmp() were used without including <cstring>; the
// header only compiled through transitive includes.
#include <cstring>
#include <ostream>
#include <string>

namespace tflite {
namespace ops {
namespace custom {
namespace sentencepiece {

// AOSP and WASM doesn't support string_view,
// we put here a minimal re-implementation.
namespace utils {

// Non-owning view over a character buffer; the viewed data must outlive
// the view.
class string_view {
 public:
  explicit string_view(const std::string& s)
      : str_(s.data()), len_(s.length()) {}
  string_view(const char* str, int len) : str_(str), len_(len) {}
  // A constructor from c string.
  explicit string_view(const char* s) : str_(s), len_(strlen(s)) {}

  int length() const { return len_; }
  const char* data() const { return str_; }
  bool empty() const { return len_ == 0; }
  // Unchecked byte access; returned as unsigned to avoid sign surprises
  // with high-bit UTF-8 bytes.
  unsigned char at(int i) const { return str_[i]; }

 private:
  const char* str_ = nullptr;
  const int len_ = 0;
};

inline std::ostream& operator<<(std::ostream& os, const string_view& sv) {
  os << std::string(sv.data(), sv.length());
  return os;
}

// Byte-wise equality of the viewed contents.
inline bool operator==(const string_view& view1, const string_view& view2) {
  if (view1.length() != view2.length()) {
    return false;
  }
  return memcmp(view1.data(), view2.data(), view1.length()) == 0;
}

}  // namespace utils
}  // namespace sentencepiece
}  // namespace custom
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_UTILS_H_
b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_flex_delegate.tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_1d_input.tflite b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_1d_input.tflite Binary files differnew file mode 100644 index 00000000..03640e28 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_1d_input.tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_2d_input.tflite b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_2d_input.tflite Binary files differnew file mode 100644 index 00000000..b6883745 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_ragged_2d_input.tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_tensor.tflite b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_tensor.tflite Binary files differnew file mode 100644 index 00000000..88e5cef5 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/testdata/whitespace_tokenizer_to_tensor.tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc new file mode 100644 index 00000000..dad2f000 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.cc @@ -0,0 +1,224 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h" + +#include <algorithm> +#include <utility> +#include <vector> + +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" +#include "libutf/utf.h" + +constexpr int kInput = 0; +constexpr int kOutputValues = 0; +constexpr int kOutputRowSplitsStart = 1; + +namespace tflite { +namespace ops { +namespace custom { +namespace whitespace_tokenizer { + +// This TFLite op implements a whitespace tokenizer, and can output the +// tokens as either a padded tensor or a ragged tensor. 
+// +// If we're outputting a padded tensor, our outputs are: +// * A string tensor +// +// If we're outputting a ragged tensor, our outputs are: +// * A string tensor (the innermost values of the ragged tensor) +// * N int64 tensors (the row_splits of the ragged tensor, where N is the +// rank of the input tensor) + +inline bool OutputIsPaddedTensor(TfLiteNode* node) { + return NumOutputs(node) == 1; +} + +inline int charntorune(Rune* r, const char* s, int n) { + const int bytes_read = chartorune(r, const_cast<char *>(s)); + if (bytes_read > n) { + *r = Runeerror; + return 0; + } + return bytes_read; +} + +std::vector<std::pair<const char*, int>> Tokenize(StringRef str) { + const char* p = str.str; + int n = str.len; + + std::vector<std::pair<const char*, int>> tokens; + const char* start = nullptr; + while (n > 0) { + Rune r; + int c = charntorune(&r, p, n); + if (r == Runeerror) break; + + if (isspacerune(r)) { + if (start != nullptr) { + tokens.push_back({start, p - start}); + } + start = nullptr; + } else { + if (start == nullptr) { + start = p; + } + } + + p += c; + n -= c; + } + if (start != nullptr) { + tokens.push_back({start, p - start}); + } + + return tokens; +} + +TfLiteStatus WritePaddedOutput( + const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens, + const TfLiteTensor* input, TfLiteTensor* output_values) { + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) + 1); + for (int i = 0; i < NumDimensions(input); ++i) { + output_shape->data[i] = SizeOfDimension(input, i); + } + + size_t max_tokens = 0; + for (const auto& tokens : list_of_tokens) { + max_tokens = std::max(max_tokens, tokens.size()); + } + + output_shape->data[NumDimensions(input)] = max_tokens; + DynamicBuffer buffer; + for (const auto& tokens : list_of_tokens) { + for (const auto& token : tokens) { + buffer.AddString(token.first, token.second); + } + for (int i = tokens.size(); i < max_tokens; ++i) { + buffer.AddString(nullptr, 0); + } + } + 
buffer.WriteToTensor(output_values, output_shape); + return kTfLiteOk; +} + +TfLiteStatus WriteRaggedOutput( + const std::vector<std::vector<std::pair<const char*, int>>>& list_of_tokens, + const TfLiteTensor* input, TfLiteTensor* output_values, + std::vector<TfLiteTensor*> nested_row_splits) { + // The outer dimensions of the ragged tensor are all non-ragged. + for (int i = 0; i < nested_row_splits.size() - 1; ++i) { + int row_splits_step = SizeOfDimension(input, i + 1); + TfLiteTensor* row_splits = nested_row_splits[i]; + for (int j = 0; j < SizeOfDimension(row_splits, 0); ++j) { + row_splits->data.i64[j] = j * row_splits_step; + } + } + + // Generate the innermost row_splits and values tensors. + TfLiteTensor* row_splits = nested_row_splits.back(); + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(1); + DynamicBuffer buffer; + int token_index = 0; + int row_splits_index = 0; + for (const auto& tokens : list_of_tokens) { + row_splits->data.i64[row_splits_index] = token_index; + for (const auto& token : tokens) { + buffer.AddString(token.first, token.second); + ++token_index; + } + ++row_splits_index; + } + row_splits->data.i64[row_splits_index] = token_index; + output_shape->data[0] = token_index; + buffer.WriteToTensor(output_values, output_shape); + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + SetTensorToDynamic(output_values); + + if (OutputIsPaddedTensor(node)) { + return kTfLiteOk; + } + + const TfLiteTensor* input = GetInput(context, node, kInput); + TF_LITE_ENSURE(context, NumDimensions(input) == + (NumOutputs(node) - kOutputRowSplitsStart)); + + // Resize the row_splits tensors. We're just adding a ragged inner + // dimension to the shape of the input tensor, so the size of the + // row_splits tensors can be calculated using the input tensor's shape. 
+ int input_size = 1; + for (int i = 0; i < NumDimensions(input); ++i) { + input_size *= SizeOfDimension(input, i); + + TfLiteIntArray* row_splits_shape = TfLiteIntArrayCreate(1); + row_splits_shape->data[0] = input_size + 1; + TfLiteTensor* row_splits = + GetOutput(context, node, kOutputRowSplitsStart + i); + TF_LITE_ENSURE_STATUS( + context->ResizeTensor(context, row_splits, row_splits_shape)); + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInput); + int input_size = 1; + for (int i = 0; i < NumDimensions(input); ++i) { + input_size *= SizeOfDimension(input, i); + } + + std::vector<std::vector<std::pair<const char*, int>>> list_of_tokens; + list_of_tokens.reserve(input_size); + for (int i = 0; i < input_size; ++i) { + list_of_tokens.emplace_back(Tokenize(GetString(input, i))); + } + + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + TF_LITE_ENSURE(context, IsDynamicTensor(output_values)); + + if (OutputIsPaddedTensor(node)) { + return WritePaddedOutput(list_of_tokens, input, output_values); + } + + std::vector<TfLiteTensor*> nested_row_splits; + nested_row_splits.reserve(NumDimensions(input)); + for (int i = 0; i < NumDimensions(input); ++i) { + TfLiteTensor* output_row_splits = + GetOutput(context, node, kOutputRowSplitsStart + i); + nested_row_splits.push_back(output_row_splits); + } + return WriteRaggedOutput(list_of_tokens, input, output_values, + nested_row_splits); +} + +} // namespace whitespace_tokenizer + +TfLiteRegistration* Register_tftext_WhitespaceTokenizer() { + static TfLiteRegistration r = {nullptr, nullptr, + whitespace_tokenizer::Prepare, + whitespace_tokenizer::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h new file mode 100644 index 
00000000..b1902480 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_H_ + +#include "tensorflow/lite/context.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_tftext_WhitespaceTokenizer(); + +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc new file mode 100644 index 00000000..534fbef4 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.cc @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h" + +#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h" +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace tflite { +namespace ops { +namespace custom { + +void AddWhitespaceTokenizerCustomOp(MutableOpResolver* resolver) { + resolver->AddCustom("tftext:WhitespaceTokenizer", + Register_tftext_WhitespaceTokenizer()); +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h new file mode 100644 index 00000000..4f57d8d8 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_OP_RESOLVER_H_ +#define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_OP_RESOLVER_H_ + +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace tflite { +namespace ops { +namespace custom { + +// Adds the WhitespaceTokenizer custom op to an op resolver. +// This function can be loaded using dlopen. Since C++ function names get +// mangled, declare this function as extern C, so its name is unchanged. +extern "C" void AddWhitespaceTokenizerCustomOp(MutableOpResolver* resolver); + +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_WHITESPACE_TOKENIZER_OP_RESOLVER_H_ diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver_wrapper.cc b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver_wrapper.cc new file mode 100644 index 00000000..03d3ba89 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver_wrapper.cc @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "pybind11/pybind11.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_op_resolver.h" + +PYBIND11_MODULE(_pywrap_whitespace_tokenizer_op_resolver, m) { + m.doc() = "_pywrap_whitespace_tokenizer_op_resolver"; + m.def( + "AddWhitespaceTokenizerCustomOp", + [](uintptr_t resolver) { + tflite::ops::custom::AddWhitespaceTokenizerCustomOp( + reinterpret_cast<tflite::MutableOpResolver*>(resolver)); + }, + "Op registerer function for the tftext:WhitespaceTokenizer custom op."); +} diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.cc b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.cc new file mode 100644 index 00000000..4654e46c --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.cc @@ -0,0 +1,189 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer.h" + +#include <string> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace whitespace_tokenizer { +namespace test { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +} // namespace + +enum OutputType { PADDED, RAGGED }; + +class WhitespaceTokenizerModel : public SingleOpModel { + public: + WhitespaceTokenizerModel(OutputType output_type, + const std::vector<std::string>& input_values, + const std::vector<int>& input_shape) + : input_shape_(input_shape) { + input_ = AddInput(TensorType_STRING); + output_values_ = AddOutput(TensorType_STRING); + if (output_type == RAGGED) { + for (int i = 0; i < input_shape_.size(); ++i) { + output_row_splits_.push_back(AddOutput(TensorType_INT64)); + } + } + SetCustomOp("WhitespaceTokenizer", {}, Register_tftext_WhitespaceTokenizer); + + BuildInterpreter({input_shape}); + PopulateStringTensor(input_, input_values); + Invoke(); + } + + std::vector<int> GetValuesTensorShape() { + return GetTensorShape(output_values_); + } + + std::vector<std::string> ExtractValuesTensorVector() { + std::vector<std::string> r; + TfLiteTensor* tensor = interpreter_->tensor(output_values_); + int n = GetStringCount(tensor); + for (int i = 0; i < n; ++i) { + StringRef ref = GetString(tensor, i); + r.emplace_back(ref.str, ref.len); + } + return r; + } + + void CheckRowSplits(const std::vector<int>& token_counts) { + int size = 1; + for (int i = 0; i < input_shape_.size(); ++i) { + size *= input_shape_[i]; + EXPECT_THAT(GetTensorShape(output_row_splits_[i]), ElementsAre(size + 1)) + << "row_splits " << i << " has the wrong shape"; + + 
std::vector<int64_t> expected_values(size + 1); + if (i == input_shape_.size() - 1) { + ASSERT_EQ(token_counts.size(), size); + + int index = 0; + expected_values[0] = index; + for (int j = 0; j < size; ++j) { + index += token_counts[j]; + expected_values[j + 1] = index; + } + } else { + for (int j = 0; j <= size; ++j) { + expected_values[j] = j * input_shape_[i + 1]; + } + } + EXPECT_THAT(ExtractVector<int64_t>(output_row_splits_[i]), + ElementsAreArray(expected_values)) + << "row_splits " << i << " has an incorrect value/index"; + } + } + + private: + int input_; + std::vector<int> input_shape_; + int output_values_; + std::vector<int> output_row_splits_; +}; // class WhitespaceTokenizerModel + +TEST(WhitespaceTokenizerTest, SingleStringPaddedOutput) { + WhitespaceTokenizerModel m(PADDED, {"this is a test"}, {1}); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(1, 4)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("this", "is", "a", "test")); +} + +TEST(WhitespaceTokenizerTest, SingleStringRaggedOutput) { + WhitespaceTokenizerModel m(RAGGED, {"this is a test"}, {1}); + m.CheckRowSplits({4}); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("this", "is", "a", "test")); +} + +TEST(WhitespaceTokenizerTest, VectorPaddedOutput) { + WhitespaceTokenizerModel m(PADDED, + {"this is a test", // + "three token sentence", // + "many more tokens than that sentence"}, + {3}); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3, 6)); + EXPECT_THAT( + m.ExtractValuesTensorVector(), + ElementsAre("this", "is", "a", "test", "", "", // + "three", "token", "sentence", "", "", "", // + "many", "more", "tokens", "than", "that", "sentence")); +} + +TEST(WhitespaceTokenizerTest, VectorRaggedOutput) { + WhitespaceTokenizerModel m(RAGGED, + {"this is a test", // + "three token sentence", // + "many more tokens than that sentence"}, + {3}); + m.CheckRowSplits({4, 3, 6}); + EXPECT_THAT( + m.ExtractValuesTensorVector(), + ElementsAre("this", "is", "a", "test", // + "three", 
"token", "sentence", // + "many", "more", "tokens", "than", "that", "sentence")); +} + +TEST(WhitespaceTokenizerTest, MatrixPaddedOutput) { + WhitespaceTokenizerModel m(PADDED, + {"a b c", "d e f", // + "g h", "i j k l", // + "m", "n o p q r"}, + {3, 2}); + EXPECT_THAT(m.GetValuesTensorShape(), ElementsAre(3, 2, 5)); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("a", "b", "c", "", "", // + "d", "e", "f", "", "", // + "g", "h", "", "", "", // + "i", "j", "k", "l", "", // + "m", "", "", "", "", // + "n", "o", "p", "q", "r")); +} + +TEST(WhitespaceTokenizerTest, MatrixRAGGEDOutput) { + WhitespaceTokenizerModel m(RAGGED, + {"a b c", "d e f", // + "g h", "i j k l", // + "m", "n o p q r"}, + {3, 2}); + m.CheckRowSplits({3, 3, 2, 4, 1, 5}); + EXPECT_THAT(m.ExtractValuesTensorVector(), + ElementsAre("a", "b", "c", // + "d", "e", "f", // + "g", "h", // + "i", "j", "k", "l", // + "m", // + "n", "o", "p", "q", "r")); +} + +} // namespace test +} // namespace whitespace_tokenizer +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py new file mode 100644 index 00000000..b6a1a67d --- /dev/null +++ b/tensorflow_lite_support/custom_ops/kernel/whitespace_tokenizer_test.py @@ -0,0 +1,168 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Lint as: python3 +"""Tests for tensorflow_lite_support.custom_ops.kernel.whitespace_tokenizer.""" + +import os +import sys +import timeit + +from absl import logging +from absl.testing import parameterized +import numpy as np +import tensorflow as tf +import tensorflow_text as tf_text +# pylint: disable=g-direct-tensorflow-import +from tensorflow.lite.python import interpreter as interpreter_wrapper +from tensorflow.python.platform import resource_loader + +# Force loaded shared object symbols to be globally visible. This is needed so +# that the interpreter_wrapper, in one .so file, can see the op resolver +# in a different .so file. Note that this may already be set by default. +# pylint: disable=g-import-not-at-top,g-bad-import-order,unused-import +if hasattr(sys, 'setdlopenflags') and hasattr(sys, 'getdlopenflags'): + sys.setdlopenflags(sys.getdlopenflags() | os.RTLD_GLOBAL) +from tensorflow_lite_support.custom_ops.kernel import _pywrap_whitespace_tokenizer_op_resolver + +TEST_CASES = [ + ['this is a test'], + ['extra spaces in here'], + ['a four token sentence', 'a five token sentence thing.'], + [['a multi dimensional test case', 'a b c d', 'e f g'], + ['h i j', 'k l m 2 3', 'n o p'], ['q r s 0 1', 't u v', 'w x y z']], +] + +INVOKES_FOR_SINGLE_OP_BENCHMARK = 1000 +INVOKES_FOR_FLEX_DELEGATE_BENCHMARK = 10 + + +@tf.function +def _call_whitespace_tokenizer_to_tensor(test_case): + tokenizer = tf_text.WhitespaceTokenizer() + return tokenizer.tokenize(test_case).to_tensor() + + +@tf.function +def _call_whitespace_tokenizer_to_ragged(test_case): + tokenizer = tf_text.WhitespaceTokenizer() + return tokenizer.tokenize(test_case) + + +class WhitespaceTokenizerTest(parameterized.TestCase): + + @parameterized.parameters([t] for t in TEST_CASES) + def testToTensorEquivalence(self, test_case): + tf_output = _call_whitespace_tokenizer_to_tensor(test_case) + + model_filename = 
resource_loader.get_path_to_datafile( + 'testdata/whitespace_tokenizer_to_tensor.tflite') + with open(model_filename, 'rb') as file: + model = file.read() + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=model, + custom_op_registerers=['AddWhitespaceTokenizerCustomOp']) + + np_test_case = np.array(test_case, dtype=np.str) + interpreter.resize_tensor_input(0, np_test_case.shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + np_test_case) + interpreter.invoke() + tflite_output = interpreter.get_tensor( + interpreter.get_output_details()[0]['index']) + + self.assertEqual(tf_output.numpy().tolist(), tflite_output.tolist()) + + @parameterized.parameters([t] for t in TEST_CASES) + def testToRaggedEquivalence(self, test_case): + tf_output = _call_whitespace_tokenizer_to_ragged(test_case) + + np_test_case = np.array(test_case, dtype=np.str) + rank = len(np_test_case.shape) + + model_filename = resource_loader.get_path_to_datafile( + 'testdata/whitespace_tokenizer_to_ragged_{}d_input.tflite'.format(rank)) + with open(model_filename, 'rb') as file: + model = file.read() + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=model, + custom_op_registerers=['AddWhitespaceTokenizerCustomOp']) + interpreter.resize_tensor_input(0, np_test_case.shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + np_test_case) + interpreter.invoke() + + # Traverse the nested row_splits/values of the ragged tensor. 
+ for i in range(rank): + tflite_output_cur_row_splits = interpreter.get_tensor( + interpreter.get_output_details()[1 + i]['index']) + self.assertEqual(tf_output.row_splits.numpy().tolist(), + tflite_output_cur_row_splits.tolist()) + tf_output = tf_output.values + + tflite_output_values = interpreter.get_tensor( + interpreter.get_output_details()[0]['index']) + self.assertEqual(tf_output.numpy().tolist(), tflite_output_values.tolist()) + + def testSingleOpLatency(self): + model_filename = resource_loader.get_path_to_datafile( + 'testdata/whitespace_tokenizer_to_tensor.tflite') + with open(model_filename, 'rb') as file: + model = file.read() + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=model, + custom_op_registerers=['AddWhitespaceTokenizerCustomOp']) + + latency = 0.0 + for test_case in TEST_CASES: + np_test_case = np.array(test_case, dtype=np.str) + interpreter.resize_tensor_input(0, np_test_case.shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + np_test_case) + start_time = timeit.default_timer() + for _ in range(INVOKES_FOR_SINGLE_OP_BENCHMARK): + interpreter.invoke() + latency = latency + timeit.default_timer() - start_time + + latency = latency / (INVOKES_FOR_SINGLE_OP_BENCHMARK * len(TEST_CASES)) + logging.info('Latency: %fms', latency * 1000.0) + + def testFlexDelegateLatency(self): + model_filename = resource_loader.get_path_to_datafile( + 'testdata/whitespace_tokenizer_flex_delegate.tflite') + with open(model_filename, 'rb') as file: + model = file.read() + interpreter = interpreter_wrapper.Interpreter(model_content=model) + + latency = 0.0 + for test_case in TEST_CASES: + np_test_case = np.array(test_case, dtype=np.str) + interpreter.resize_tensor_input(0, np_test_case.shape) + interpreter.allocate_tensors() + interpreter.set_tensor(interpreter.get_input_details()[0]['index'], + np_test_case) + start_time = timeit.default_timer() + for _ in 
range(INVOKES_FOR_FLEX_DELEGATE_BENCHMARK): + interpreter.invoke() + latency = latency + timeit.default_timer() - start_time + + latency = latency / (INVOKES_FOR_FLEX_DELEGATE_BENCHMARK * len(TEST_CASES)) + logging.info('Latency: %fms', latency * 1000.0) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_lite_support/custom_ops/python/BUILD b/tensorflow_lite_support/custom_ops/python/BUILD new file mode 100644 index 00000000..82a8a6ec --- /dev/null +++ b/tensorflow_lite_support/custom_ops/python/BUILD @@ -0,0 +1,61 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +py_library( + name = "tflite_text_api", + srcs = ["tflite_text_api.py"], + deps = [ + # tensorflow dep, + # tensorflow_text dep, + ], +) + +py_library( + name = "sentencepiece_tokenizer", + srcs = ["sentencepiece_tokenizer.py"], + data = [ + "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:sentencepiece_detokenizer_op.so", + "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:sentencepiece_tokenizer_op.so", + ], + srcs_version = "PY3", + deps = [ + # tensorflow dep, + "//tensorflow_lite_support/custom_ops/kernel/sentencepiece/py:pywrap_model_converter", + ], +) + +py_test( + name = "sentencepiece_tokenizer_test", + srcs = ["sentencepiece_tokenizer_test.py"], + data = [ + "//tensorflow_lite_support/custom_ops/kernel/sentencepiece:testdata", + ], + python_version = "PY3", + deps = [ + ":sentencepiece_tokenizer", + # tensorflow dep, + # tensorflow_text dep, + "//tensorflow_lite_support/custom_ops/kernel/sentencepiece/py:pywrap_tflite_registerer", + "@absl_py//absl:app", + "@absl_py//absl/flags", + "@absl_py//absl/logging", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "ragged_tensor_to_tensor_test", + srcs = ["ragged_tensor_to_tensor_test.py"], + python_version = "PY3", + deps = [ + # tensorflow dep, + "//tensorflow_lite_support/custom_ops/kernel/ragged/py:pywrap_tflite_registerer", + 
"@absl_py//absl:app", + "@absl_py//absl/flags", + "@absl_py//absl/logging", + ], +) diff --git a/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py b/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py new file mode 100644 index 00000000..319131e0 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/python/ragged_tensor_to_tensor_test.py @@ -0,0 +1,57 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +"""Tests for ragged_tensor_to_tensor.""" + +import tensorflow as tf +from tensorflow.lite.python import interpreter as interpreter_wrapper # pylint: disable=g-direct-tensorflow-import + + +class RaggedTensorToTensorTest(tf.test.TestCase): + + def test_ragged_to_tensor(self): + + @tf.function + def ragged_tensor_function(): + ragged_tensor = tf.RaggedTensor.from_row_splits( + values=[ + 13, 36, 83, 131, 13, 36, 4, 3127, 152, 130, 30, 2424, 168, 1644, + 1524, 4, 3127, 152, 130, 30, 2424, 168, 1644, 636 + ], + row_splits=[0, 0, 6, 15, 24]) + return ragged_tensor.to_tensor() + + concrete_function = ragged_tensor_function.get_concrete_function() + + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_function]) + converter.allow_custom_ops = True + tflite_model = converter.convert() + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=tflite_model, + custom_op_registerers=["TFLite_RaggedTensorToTensorRegisterer"]) + interpreter.allocate_tensors() + interpreter.invoke() + output_details = interpreter.get_output_details() + expected_result_values = [[0, 0, 0, 0, 0, 0, 0, 0, 0], + [13, 36, 83, 131, 13, 36, 0, 0, 0], + [4, 3127, 152, 130, 30, 2424, 168, 1644, 1524], + [4, 3127, 152, 130, 30, 2424, 168, 1644, 636]] + self.assertAllEqual( + interpreter.get_tensor(output_details[0]["index"]), + expected_result_values) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py new file mode 100644 index 00000000..21efed56 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py @@ -0,0 +1,125 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Lint as: python3 +"""Python class that implements Sentencepiece tokenizer. + +It follows TF.text designers design. + +""" +import tensorflow.compat.v2 as tf # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops.ragged import ragged_tensor # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader +gen_sentencepiece_detokenizer_op = load_library.load_op_library(resource_loader.get_path_to_datafile('../kernel/sentencepiece/sentencepiece_detokenizer_op.so')) +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader +gen_sentencepiece_tokenizer_op = load_library.load_op_library(resource_loader.get_path_to_datafile('../kernel/sentencepiece/sentencepiece_tokenizer_op.so')) +from tensorflow_lite_support.custom_ops.kernel.sentencepiece.py import pywrap_model_converter as model_converter + + +class SentencepieceTokenizer: + """Sentencepiece tokenizer with tf.text interface.""" + + def __init__(self, model, reverse=False, add_bos=False, add_eos=False): + converted_model = model_converter.convert_sentencepiece_model(model) + converted_model_detokenizer = model_converter.convert_sentencepiece_model_for_decoder( + model) + # Use uint8 tensor as a buffer for the model to avoid any 
possible changes, + # for example truncation by '\0'. + self._converted_model = tf.constant(list(converted_model), dtype=tf.uint8) + self._converted_model_detokenizer = tf.constant( + list(converted_model_detokenizer), dtype=tf.uint8) + self._vocab_size = model_converter.get_vocabulary_size(converted_model) + self._reverse = reverse + self._add_bos = add_bos + self._add_eos = add_eos + + def tokenize(self, inputs): + """The main tokenization function.""" + input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(inputs) + if input_tensor.shape.ndims is None: + raise ValueError("Rank of input_tensor must be statically known.") + if ragged_tensor.is_ragged(input_tensor): + # Ensure that input has row_split_dtype is int32 + input_tensor = input_tensor.with_row_splits_dtype(tf.int32) + # Recursively process the values of the ragged tensor. + tokens = self.tokenize(input_tensor.flat_values) + return input_tensor.with_flat_values(tokens) + else: + if input_tensor.shape.ndims > 1: + # Convert the input tensor to ragged and process it. + return self.tokenize( + tf.RaggedTensor.from_tensor( + input_tensor, row_splits_dtype=tf.int32)) + elif input_tensor.shape.ndims == 0: + tokens = self.tokenize(tf.stack([input_tensor])) + return tokens.values + else: + # Our rank 1 tensor is the correct shape, so we can process it as + # normal. + (output_values, row_splits) = ( + gen_sentencepiece_tokenizer_op.tf_sentencepiece_tokenize_op( + self._converted_model, input_tensor, 0, 0, self._add_bos, + self._add_eos, self._reverse)) + tokens = tf.RaggedTensor.from_nested_row_splits( + flat_values=output_values, + nested_row_splits=[row_splits], + validate=False) + return tokens + + def detokenize(self, input): # pylint: disable=redefined-builtin + """Detokenizes tokens into preprocessed text. + + Args: + input: A `RaggedTensor` or `Tensor` with int32 encoded text with rank >= + 1. + + Returns: + A N-1 dimensional string Tensor or RaggedTensor of the detokenized text. 
+ """ + input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input) + if input_tensor.shape.ndims is None: + raise ValueError("Rank of input_tensor must be statically known.") + if input_tensor.shape.ndims == 0: + raise ValueError("Rank of input_tensor must be at least 1.") + if ragged_tensor.is_ragged(input_tensor): + if input_tensor.flat_values.shape.ndims > 1: + # If the flat_values of our ragged tensor is multi-dimensional, we can + # process it separately and our output will have the same nested + # splits as our input. + tokens = self.detokenize(input_tensor.flat_values) + return input_tensor.with_flat_values(tokens) + elif input_tensor.ragged_rank > 1: + # Recursively process the values of the ragged tensor. + tokens = self.detokenize(input_tensor.values) + return input_tensor.with_values(tokens) + else: + return gen_sentencepiece_detokenizer_op.tf_sentencepiece_detokenize_op( + self._converted_model_detokenizer, input_tensor.flat_values, + input_tensor.row_splits) + else: + if input_tensor.shape.ndims > 1: + # Convert the input tensor to ragged and process it. + return self.detokenize( + tf.RaggedTensor.from_tensor( + input_tensor, row_splits_dtype=tf.int32)) + else: + tokens = self.detokenize(tf.stack([input_tensor])) + return tf.reshape(tokens, []) + + def vocab_size(self): + """Returns size of the vocabulary in Sentencepiece model.""" + return self._vocab_size diff --git a/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py new file mode 100644 index 00000000..3609b469 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer_test.py @@ -0,0 +1,251 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Lint as: python3 +"""Tests for sentencepiece_tokenizer.""" + +import os +import sys +import time + +from absl import flags +import numpy as np +import tensorflow.compat.v2 as tf # pylint: disable=g-direct-tensorflow-import +import tensorflow_text +# Force loaded shared object symbols to be globally visible. This is needed so +# that the interpreter_wrapper, in one .so file, can see the op resolver +# in a different .so file. Note that this may already be set by default. +# pylint: disable=g-import-not-at-top,g-bad-import-order,unused-import +if hasattr(sys, "setdlopenflags") and hasattr(sys, "getdlopenflags"): + sys.setdlopenflags(sys.getdlopenflags() | os.RTLD_GLOBAL) +from tensorflow.lite.python import interpreter as interpreter_wrapper # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.platform import resource_loader +from tensorflow_lite_support.custom_ops.python import sentencepiece_tokenizer +from tensorflow_lite_support.custom_ops.kernel.sentencepiece.py import pywrap_tflite_registerer + +FLAGS = flags.FLAGS + +SENTENCEPIECE_MODEL_FILE = ( + "../kernel/sentencepiece/testdata/sentencepiece.model") + + +def _GetSentencepieceModel(): + model_filename = resource_loader.get_path_to_datafile( + SENTENCEPIECE_MODEL_FILE) + with open(model_filename, "rb") as file: + model = file.read() + return model + + +class SentencepieceTokenizerTest(tf.test.TestCase): + + def setUp(self): + super(SentencepieceTokenizerTest, self).setUp() + self.sentencepiece_model = 
_GetSentencepieceModel() + + def test_tftext_sentencepiece_tokenizer(self): + """Check that the new tokenizer produces the same result that the tftext one.""" + tftext_sp = tensorflow_text.SentencepieceTokenizer(self.sentencepiece_model) + opt_sp = sentencepiece_tokenizer.SentencepieceTokenizer( + self.sentencepiece_model) + + input_text = [ + u" ", u"to be or not to be", u"ignored by length text1", + u"ignored by length text2" + ] + tftext_tokenized = tftext_sp.tokenize(input_text) + opt_tokenized = opt_sp.tokenize(input_text) + self.assertAllEqual(tftext_tokenized, opt_tokenized) + + def test_tftext_sentencepiece_detokenizer(self): + """Check that the new tokenizer produces the same result that the tftext one.""" + tftext_sp = tensorflow_text.SentencepieceTokenizer(self.sentencepiece_model) + opt_sp = sentencepiece_tokenizer.SentencepieceTokenizer( + self.sentencepiece_model) + + input_text = [ + u" ", u"to be or not to be", u"ignored by length text1", + u"ignored by length text2" + ] + tftext_tokenized = tftext_sp.tokenize(input_text) + + # Check detokenizer + tftext_detokenized = tftext_sp.detokenize(tftext_tokenized) + opt_detokenized = opt_sp.detokenize(tftext_tokenized) + self.assertAllEqual(tftext_detokenized, opt_detokenized) + + def test_tftext_sentencepiece_tokenizer_bos_eos(self): + """Check that the new tokenizer produces the same result that the tftext one with bos and eos.""" + tftext_sp = tensorflow_text.SentencepieceTokenizer( + self.sentencepiece_model, add_bos=True, add_eos=True) + opt_sp = sentencepiece_tokenizer.SentencepieceTokenizer( + self.sentencepiece_model, add_bos=True, add_eos=True) + + input_text = [ + u" ", u"to be or not to be", u"ignored by length text1", + u"ignored by length text2" + ] + tftext_tokenized = tftext_sp.tokenize(input_text) + opt_tokenized = opt_sp.tokenize(input_text) + self.assertAllEqual(tftext_tokenized, opt_tokenized) + + def test_tflite_opt_sentence_tokenizer(self): + """Check that can convert a Keras model to 
TFLite and it produces the same result for tokenization.""" + + class TokenizerLayer(tf.keras.layers.Layer): + + def __init__(self, sentencepiece_model, **kwargs): + super(TokenizerLayer, self).__init__(**kwargs) + self.sp = sentencepiece_tokenizer.SentencepieceTokenizer( + sentencepiece_model) + + def call(self, input_tensor, **kwargs): + return self.sp.tokenize(input_tensor).flat_values + + model = tf.keras.models.Sequential( + [TokenizerLayer(self.sentencepiece_model)]) + input_data = np.array([[ + u" ", u"to be or not to be", u"ignored by length text1", + u"ignored by length text2" + ]]) + tf_result = model.predict(input_data) + converter = tf.lite.TFLiteConverter.from_keras_model(model) + supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] + converter.target_spec.supported_ops = supported_ops + converter.allow_custom_ops = True + tflite_model = converter.convert() + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=tflite_model, + custom_op_registerers=["TFLite_SentencepieceTokenizerRegisterer"]) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + + interpreter.set_tensor(input_details[0]["index"], input_data) + interpreter.invoke() + output_details = interpreter.get_output_details() + expected_result = [ + 13, 36, 83, 131, 13, 36, 4, 3127, 152, 130, 30, 2424, 168, 1644, 1524, + 4, 3127, 152, 130, 30, 2424, 168, 1644, 636 + ] + self.assertAllEqual(tf_result, expected_result) + self.assertAllEqual( + interpreter.get_tensor(output_details[0]["index"]), expected_result) + + def test_tflite_opt_sentence_detokenizer(self): + """Check that can convert a Keras model to TFLite and it produces the same result for tokenization.""" + + class DeTokenizerLayer(tf.keras.layers.Layer): + + def __init__(self, sentencepiece_model, **kwargs): + super(DeTokenizerLayer, self).__init__(**kwargs) + self.sp = sentencepiece_tokenizer.SentencepieceTokenizer( + sentencepiece_model) + + def call(self, input_tensor, **kwargs): + 
return self.sp.detokenize(input_tensor) + + model = tf.keras.models.Sequential( + [DeTokenizerLayer(self.sentencepiece_model)]) + input_data = np.array([[ + 13, 36, 83, 131, 13, 36, 4, 3127, 152, 130, 30, 2424, 168, 1644, 1524, + 4, 3127, 152, 130, 30, 2424, 168, 1644, 636 + ]], + dtype=np.int32) + tf_result = model.predict(input_data) + converter = tf.lite.TFLiteConverter.from_keras_model(model) + supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] + converter.target_spec.supported_ops = supported_ops + converter.allow_custom_ops = True + tflite_model = converter.convert() + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=tflite_model, + custom_op_registerers=["TFLite_SentencepieceTokenizerRegisterer"]) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + + interpreter.set_tensor(input_details[0]["index"], input_data) + interpreter.invoke() + output_details = interpreter.get_output_details() + expected_result = [ + "to be or not to be ignored by length text1 ignored by length text2" + ] + self.assertAllEqual(tf_result, expected_result) + self.assertAllEqual( + interpreter.get_tensor(output_details[0]["index"]), expected_result) + + def test_tflite_opt_sentence_tokenizer_vocab_size(self): + """Check that can convert a Keras model to TFLite and it produces the same result for vocabulary size.""" + + class TokenizerLayer(tf.keras.layers.Layer): + + def __init__(self, sentencepiece_model, **kwargs): + super(TokenizerLayer, self).__init__(**kwargs) + self.sp = sentencepiece_tokenizer.SentencepieceTokenizer( + sentencepiece_model) + + def call(self, input_tensor, **kwargs): + return self.sp.vocab_size() + + model = tf.keras.models.Sequential( + [TokenizerLayer(self.sentencepiece_model)]) + input_data = np.array([[""]]) + tf_result = model.predict(input_data) + converter = tf.lite.TFLiteConverter.from_keras_model(model) + supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] + converter.target_spec.supported_ops = 
supported_ops + converter.allow_custom_ops = True + tflite_model = converter.convert() + interpreter = interpreter_wrapper.InterpreterWithCustomOps( + model_content=tflite_model, + custom_op_registerers=["TFLite_SentencepieceTokenizerRegisterer"]) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + interpreter.set_tensor(input_details[0]["index"], input_data) + interpreter.invoke() + output_details = interpreter.get_output_details() + expected_result = 4000 + self.assertEqual(tf_result, expected_result) + self.assertAllEqual( + interpreter.get_tensor(output_details[0]["index"]), expected_result) + + +class SentencepieceTokenizerBenchmark(tf.test.Benchmark): + + def benchmarkTokenizer(self): + sp_model = _GetSentencepieceModel() + test_text = [ + "This week we celebrate the casts and creatives who have come together" + " to bring us our favorite.", + "More Stacks products demonstrated commitment to excellent support.", + "Test, test, test." + ] + + tftext_sp = tensorflow_text.SentencepieceTokenizer(sp_model) + opt_sp = sentencepiece_tokenizer.SentencepieceTokenizer(sp_model) + iter_number = 1000 + start = time.time() + for _ in range(iter_number): + _ = opt_sp.tokenize(test_text) + self.report_benchmark( + iters=iter_number, wall_time=time.time() - start, name="opt") + start = time.time() + for _ in range(iter_number): + _ = tftext_sp.tokenize(test_text) + self.report_benchmark( + iters=iter_number, wall_time=time.time() - start, name="tf.text") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_lite_support/custom_ops/python/tflite_text_api.py b/tensorflow_lite_support/custom_ops/python/tflite_text_api.py new file mode 100644 index 00000000..1466df29 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/python/tflite_text_api.py @@ -0,0 +1,126 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Wrapped TF.Text friendly to Tensorflow Lite conversion.""" + +import tensorflow as tf +import tensorflow_text as tf_text + + +class WhitespaceTokenizer(tf_text.Tokenizer): + """TFLite friendly API for tensorflow_text.WhitspaceTokenizer.tokenize. + + The strings are split on ICU defined whitespace characters. These + whitespace characters are dropped. See more details in + https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WhitespaceTokenizer.md + + Does not currently support tokenize_with_offsets(). + """ + + def __init__(self): + super(WhitespaceTokenizer, self).__init__() + self._tokenizer = tf_text.WhitespaceTokenizer() + + def tokenize(self, input_tensor): + """Tokenize input strings. + + Args: + input_tensor: A `Tensor` of UTF-8 strings with rank 0, 1 or 2. + + Returns: + A `RaggedTensor` of tokenized text. The returned shape is the shape of the + input tensor with an added ragged dimension for tokens of each string. + """ + + @tf.function(experimental_implements='name: "tftext:WhitespaceTokenizer"') + def func(input_tensor): + return self._tokenizer.tokenize(input_tensor) + + return func(input_tensor) + + +def ngrams(data, + width, + axis=-1, + reduction_type=None, + string_separator=' ', + name=None): + """TFLite friendly API for tensorflow_text.ngrams. 
+ + Creates a tensor of n-grams based data, a token tensor. See more details in + https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/ngrams.md + + Args: + data: The data to reduce. Must be convertible into a tf.Tensor or a + tf.RaggedTensor (in which case it will be deconstructed into its component + tf.Tensors). + width: The width of the ngram window. If there is not sufficient data to + fill out the ngram window, the resulting ngram will be empty. + axis: The axis to create ngrams along. Note that for string join reductions, + only axis '-1' is supported; for other reductions, any positive or + negative axis can be used. Should be a constant. + reduction_type: A member of the Reduction enum. Should be a constant. + Currently supports: + * `Reduction.STRING_JOIN`: Join strings in the window. Note that axis must + be -1 here. + string_separator: The separator string used for `Reduction.STRING_JOIN`. + Ignored otherwise. Must be a string constant, not a Tensor. + name: The op name. + + Returns: + A tensor of ngrams. If `data` is a ragged tensor, this will be a ragged + tensor. Otherwise it will be a plain tensor. 
+ """ + + if reduction_type is not tf_text.Reduction.STRING_JOIN: + # TODO(b/162082752): Provide support for Reduction.SUM and Reduction.MEAN + raise tf.errors.InvalidArgumentError( + None, None, 'only Reduction.STRING_JOIN is currently supported') + + if reduction_type is tf_text.Reduction.STRING_JOIN and axis != -1: + raise tf.errors.InvalidArgumentError( + None, None, 'For Reduction.STRING_JOIN, axis must be -1') + + experimental_implements = [ + 'name: "tftext:Ngrams"', + 'attr { key: "width" value { i: %d } }' % width, + 'attr { key: "axis" value { i: %d } }' % axis, + 'attr { key: "reduction_type" value { s: "STRING_JOIN" } }', + 'attr { key: "string_separator" value { s: "%s" } }' % string_separator, + ] + experimental_implements = ' '.join(experimental_implements) + + if isinstance(data, tf.RaggedTensor): + + # Since `data` can not be converted directly into a Tensor, we define + # ragged_func() which takes a deconstructed tf.RaggedTensor + # (one flat_values tensor and N row_splits tensors), pass it the + # deconstructed version of `data`, and then immediately reconstruct it + # within ragged_func(). 
+ @tf.function(experimental_implements=experimental_implements) + def ragged_func(values, *args): + ragged_tensor = tf.RaggedTensor.from_nested_row_splits( + flat_values=values, nested_row_splits=args) + return tf_text.ngrams(ragged_tensor, width, axis, reduction_type, + string_separator, name) + + return ragged_func(data.flat_values, *data.nested_row_splits) + + @tf.function(experimental_implements=experimental_implements) + def func(data): + return tf_text.ngrams(data, width, axis, reduction_type, string_separator, + name) + + return func(data) diff --git a/tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite b/tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite Binary files differnew file mode 100644 index 00000000..e841b964 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/testdata/sentencepiece_tokenizer_flex_op.tflite diff --git a/tensorflow_lite_support/custom_ops/tf_configure.sh b/tensorflow_lite_support/custom_ops/tf_configure.sh new file mode 100644 index 00000000..dbc96da7 --- /dev/null +++ b/tensorflow_lite_support/custom_ops/tf_configure.sh @@ -0,0 +1,60 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +function write_action_env_to_bazelrc() { + echo "build --action_env $1=\"$2\"" >> .bazelrc +} + +function is_linux() { + [[ "${PLATFORM}" == "linux" ]] +} + +function is_macos() { + [[ "${PLATFORM}" == "darwin" ]] +} + +function is_windows() { + # On windows, the shell script is actually running in msys + [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]] +} + +TF_CFLAGS=( $(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) +TF_LFLAGS="$(python3 -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')" +HEADER_DIR=${TF_CFLAGS:2} +if is_windows; then + SHARED_LIBRARY_DIR=${SHARED_LIBRARY_DIR//\\//} + SHARED_LIBRARY_NAME=${SHARED_LIBRARY_NAME//\\//} + HEADER_DIR=${HEADER_DIR//\\//} +fi +if is_windows; then + # Use pywrap_tensorflow instead of tensorflow_framework on Windows + SHARED_LIBRARY_DIR=${TF_CFLAGS:2:-7}"python" +else + SHARED_LIBRARY_DIR=${TF_LFLAGS:2} +fi +SHARED_LIBRARY_NAME=$(echo $TF_LFLAGS | rev | cut -d":" -f1 | rev) +if ! [[ $TF_LFLAGS =~ .*:.* ]]; then + if is_macos; then + SHARED_LIBRARY_NAME="libtensorflow_framework.dylib" + elif is_windows; then + # Use pywrap_tensorflow's import library on Windows. It is in the same dir as the dll/pyd. + SHARED_LIBRARY_NAME="_pywrap_tensorflow_internal.lib" + else + SHARED_LIBRARY_NAME="libtensorflow_framework.so" + fi +fi +write_action_env_to_bazelrc "TF_HEADER_DIR" ${HEADER_DIR} +write_action_env_to_bazelrc "TF_SHARED_LIBRARY_DIR" ${SHARED_LIBRARY_DIR} +write_action_env_to_bazelrc "TF_SHARED_LIBRARY_NAME" ${SHARED_LIBRARY_NAME} diff --git a/tensorflow_lite_support/custom_ops/tflite_inference_main.cc b/tensorflow_lite_support/custom_ops/tflite_inference_main.cc new file mode 100644 index 00000000..2819deea --- /dev/null +++ b/tensorflow_lite_support/custom_ops/tflite_inference_main.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This program runs the tflite model specified in --model with random inputs. +// For string type, the input is filled with a fixed string. + +#include <string> + +#include <glog/logging.h> +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/tools/command_line_flags.h" + +void FillRandomString(tflite::DynamicBuffer* buffer, + const TfLiteIntArray* dim_array, + const std::function<std::string()>& random_func) { + int num_elements = 1; + for (size_t i = 0; i < dim_array->size; i++) { + num_elements *= dim_array->data[i]; + } + for (int i = 0; i < num_elements; ++i) { + auto str = random_func(); + buffer->AddString(str.data(), str.length()); + } +} + +void RunWithRandomInputs(const std::string& filename) { + std::unique_ptr<tflite::FlatBufferModel> model = + tflite::FlatBufferModel::BuildFromFile(filename.c_str()); + + // Build the interpreter + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr<tflite::Interpreter> interpreter; + if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) { + LOG(FATAL) << "Could not initialize interpreter for TFLite 
model."; + } + + // Resize input tensors, if desired. + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Could not allocate tensor."; + } + + // Fill the random data. + std::vector<std::vector<uint8_t>> sample; + for (int tensor_idx : interpreter->inputs()) { + auto tensor = interpreter->tensor(tensor_idx); + if (tensor->type == kTfLiteString) { + tflite::DynamicBuffer buffer; + FillRandomString(&buffer, tensor->dims, []() { + return "we're have some friends over saturday to hang out in the " + "yard"; + }); + buffer.WriteToTensor(tensor, /*new_shape=*/nullptr); + } else { + std::vector<uint8_t> data(tensor->bytes); + for (auto it = data.begin(); it != data.end(); ++it) { + *it = random(); + } + sample.push_back(data); + tensor->data.raw = reinterpret_cast<char*>(sample.rbegin()->data()); + } + } + + // Running inference. + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to run the model."; + } + + // Get the output. + for (int tensor_idx : interpreter->outputs()) { + auto tensor = interpreter->tensor(tensor_idx); + LOG(INFO) << "Output type: " << TfLiteTypeGetName(tensor->type); + } +} + +int main(int argc, char** argv) { + // Parse flags to get the filename. + std::string filename; + std::vector<tflite::Flag> flag_list{tflite::Flag::CreateFlag( + "model", &filename, "The tflite model to run sample inference.", + tflite::Flag::kRequired)}; + tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list); + tensorflow::port::InitMain(argv[0], &argc, &argv); + + // Run the model with random inputs. 
+ RunWithRandomInputs(filename); + return 0; +} diff --git a/tensorflow_lite_support/examples/task/text/desktop/BUILD b/tensorflow_lite_support/examples/task/text/desktop/BUILD new file mode 100644 index 00000000..067d59cb --- /dev/null +++ b/tensorflow_lite_support/examples/task/text/desktop/BUILD @@ -0,0 +1,68 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:users", + ], + licenses = ["notice"], # Apache 2.0 +) + +# Example usage: +# bazel run -c opt \ +# tensorflow_lite_support/examples/task/text/desktop:bert_question_answerer_demo \ +# -- \ +# --model_path=/path/to/model.tflite \ +# --question="question to ask" \ +# --context="context for the question to ask" +cc_binary( + name = "bert_question_answerer_demo", + srcs = ["bert_question_answerer_demo.cc"], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) + +# Example usage: +# bazel run -c opt \ +# tensorflow_lite_support/examples/task/text/desktop:bert_nl_classifier_demo \ +# -- \ +# --model_path=/path/to/model.tflite \ +# --text="text to classify" +cc_binary( + name = "bert_nl_classifier_demo", + srcs = ["bert_nl_classifier_demo.cc"], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:category", + "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) + +# Example usage: +# bazel run -c opt \ +# tensorflow_lite_support/examples/task/text/desktop:nl_classifier_demo \ +# -- \ +# --model_path=/path/to/model.tflite \ +# --text="text to classify" \ +# --input_tensor_name="input_text" \ +# 
--output_score_tensor_name="probability" +cc_binary( + name = "nl_classifier_demo", + srcs = ["nl_classifier_demo.cc"], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:category", + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) diff --git a/tensorflow_lite_support/examples/task/text/desktop/README.md b/tensorflow_lite_support/examples/task/text/desktop/README.md new file mode 100644 index 00000000..859504fc --- /dev/null +++ b/tensorflow_lite_support/examples/task/text/desktop/README.md @@ -0,0 +1,134 @@ +# CLI Demos for C++ Text Task APIs + +This folder contains simple command-line tools for easily trying out the C++ +Text Task APIs. + +## Bert Question Answerer + +#### Prerequisites + +You will need: + +* a TFLite bert based question answerer model from model maker. +(e.g. [mobilebert][1] or [albert][2] available on TensorFlow Hub). + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1?lite-format=tflite' \ + -o /tmp/mobilebert.tflite + +# Run the classification tool: +bazel run -c opt \ + tensorflow_lite_support/examples/task/text/desktop:bert_question_answerer_demo -- \ + --model_path=/tmp/mobilebert.tflite \ + --question="Where is Amazon rainforest?" \ + --context="The Amazon rainforest, alternatively, the Amazon Jungle, also known in \ +English as Amazonia, is a moist broadleaf tropical rainforest in the Amazon \ +biome that covers most of the Amazon basin of South America. This basin \ +encompasses 7,000,000 km2 (2,700,000 sq mi), of which \ +5,500,000 km2 (2,100,000 sq mi) are covered by the rainforest. This region \ +includes territory belonging to nine nations." 
+``` + +#### Results + +In the console, you should get: + +``` +answer[0]: 'South America.' +logit: 1.84847, start_index: 39, end_index: 40 +answer[1]: 'most of the Amazon basin of South America.' +logit: 1.2921, start_index: 34, end_index: 40 +answer[2]: 'the Amazon basin of South America.' +logit: -0.0959535, start_index: 36, end_index: 40 +answer[3]: 'the Amazon biome that covers most of the Amazon basin of South America.' +logit: -0.498558, start_index: 28, end_index: 40 +answer[4]: 'Amazon basin of South America.' +logit: -0.774266, start_index: 37, end_index: 40 +``` + +## NLClassifier + +#### Prerequisites + +You will need: + +* a TFLite text classification model with certain format. +(e.g. [movie_review_model][3], a model to classify movie reviews), you'll need +to configure the input tensor and out tensor for the API, see the [doc][4] for +details. + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite' \ + -o /tmp/movie_review.tflite + +# Run the detection tool: +bazel run -c opt \ + tensorflow_lite_support/examples/task/text/desktop:nl_classifier_demo -- \ + --model_path=/tmp/movie_review.tflite \ + --text="What a waste of my time." \ + --input_tensor_name="input_text" \ + --output_score_tensor_name="probability" +``` + +#### Results + +In the console, you should get: + +``` +category[0]: 'Negative' : '0.81313' +category[1]: 'Positive' : '0.18687' +``` + +## BertNLClassifier + +#### Prerequisites + +TODO(b/163086702): Update the links to models with metadata attached. + +You will need: + +* a Bert based TFLite text classification model from model maker. (e.g. [movie_review_model][5] available on TensorFlow Hub). 
+ +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://url/to/bert/nl/classifier' \ + -o /tmp/bert_movie_review.tflite + +# Run the segmentation tool: +bazel run -c opt \ + tensorflow_lite_support/examples/task/text/desktop:bert_nl_classifier_demo -- \ + --model_path=/tmp/bert_movie_review.tflite \ + --text="it's a charming and often affecting journey" +``` + +#### Results + +In the console, you should get: + +``` +category[0]: 'negative' : '0.00006' +category[1]: 'positive' : '0.99994' +``` + +[1]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 +[2]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 +[3]: https://www.tensorflow.org/lite/models/text_classification/overview +[4]: https://github.com/tensorflow/tflite-support/blob/fe8b69002f5416900285dc69e2baa078c91bd994/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h#L55 +[5]: http://bert/nl/classifier/model diff --git a/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc b/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc new file mode 100644 index 00000000..15ea3bff --- /dev/null +++ b/tensorflow_lite_support/examples/task/text/desktop/bert_nl_classifier_demo.cc @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <iostream> +#include <limits> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" + +ABSL_FLAG(std::string, model_path, "", + "Absolute path to the '.tflite' bert classification model."); +ABSL_FLAG(std::string, text, "", "Text to classify."); + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { + +absl::Status Classify() { + ASSIGN_OR_RETURN( + std::unique_ptr<BertNLClassifier> classifier, + BertNLClassifier::CreateFromFile(absl::GetFlag(FLAGS_model_path))); + + std::vector<core::Category> categories = + classifier->Classify(absl::GetFlag(FLAGS_text)); + + for (int i = 0; i < categories.size(); ++i) { + const core::Category& category = categories[i]; + std::cout << absl::StrFormat("category[%d]: '%s' : '%.5f'\n", i, + category.class_name, category.score); + } + + return absl::OkStatus(); +} + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_text).empty()) { + std::cerr << "Missing mandatory 'question' argument.\n"; + return 1; + } + + // Run classification. 
+ absl::Status status = tflite::task::text::nlclassifier::Classify(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Classification failed: " << status.message() << "\n"; + return 1; + } +} diff --git a/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc b/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc new file mode 100644 index 00000000..743db71d --- /dev/null +++ b/tensorflow_lite_support/examples/task/text/desktop/bert_question_answerer_demo.cc @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <iostream> +#include <limits> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h" + +ABSL_FLAG(std::string, model_path, "", + "Absolute path to the '.tflite' bert question answerer model."); +ABSL_FLAG(std::string, question, "", "Question to ask."); +ABSL_FLAG(std::string, context, "", + "Context the asked question is based upon."); + +namespace tflite { +namespace task { +namespace text { +namespace qa { + +absl::Status Answer() { + ASSIGN_OR_RETURN( + std::unique_ptr<QuestionAnswerer> answerer, + BertQuestionAnswerer::CreateFromFile(absl::GetFlag(FLAGS_model_path))); + + std::vector<QaAnswer> answers = answerer->Answer( + absl::GetFlag(FLAGS_context), absl::GetFlag(FLAGS_question)); + for (int i = 0; i < answers.size(); ++i) { + const QaAnswer& answer = answers[i]; + std::cout << absl::StrFormat( + "answer[%d]: '%s'\n logit: '%.5f, start_index: %d, end_index: %d\n", + i, answer.text, answer.pos.logit, answer.pos.start, answer.pos.end); + } + + return absl::OkStatus(); +} + +} // namespace qa +} // namespace text +} // namespace task +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_question).empty()) { + std::cerr << "Missing mandatory 'question' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_context).empty()) { + std::cerr << "Missing mandatory 'context' argument.\n"; + return 1; + } + // Run the answerer. 
+ absl::Status status = tflite::task::text::qa::Answer(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Answer failed: " << status.message() << "\n"; + return 1; + } +} diff --git a/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc b/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc new file mode 100644 index 00000000..2e96ec63 --- /dev/null +++ b/tensorflow_lite_support/examples/task/text/desktop/nl_classifier_demo.cc @@ -0,0 +1,112 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <iostream> +#include <limits> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/category.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" + +ABSL_FLAG(std::string, model_path, "", + "Absolute path to the '.tflite' classification model."); +ABSL_FLAG(std::string, text, "", "Text to classify."); +ABSL_FLAG(int, input_tensor_index, -1, "Input tensor index of the model."); +ABSL_FLAG(int, output_score_tensor_index, -1, + "Output score tensor index of the model."); +ABSL_FLAG(int, output_label_tensor_index, -1, + "Output label tensor index of the model."); +ABSL_FLAG(std::string, input_tensor_name, "", + "Input tensor name of the model."); +ABSL_FLAG(std::string, output_score_tensor_name, "", + "Output score tensor name of the model."); +ABSL_FLAG(std::string, output_label_tensor_name, "", + "Output label tensor name of the model."); + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { + +absl::Status Classify() { + NLClassifierOptions options{}; + if (absl::GetFlag(FLAGS_input_tensor_index) >= 0) { + options.input_tensor_index = absl::GetFlag(FLAGS_input_tensor_index); + } + if (absl::GetFlag(FLAGS_output_score_tensor_index) >= 0) { + options.output_score_tensor_index = + absl::GetFlag(FLAGS_output_score_tensor_index); + } + if (absl::GetFlag(FLAGS_output_label_tensor_index) >= 0) { + options.output_label_tensor_index = + absl::GetFlag(FLAGS_output_label_tensor_index); + } + if (!absl::GetFlag(FLAGS_input_tensor_name).empty()) { + options.input_tensor_name = absl::GetFlag(FLAGS_input_tensor_name); + } + if (!absl::GetFlag(FLAGS_output_score_tensor_name).empty()) { + options.output_score_tensor_name = + 
absl::GetFlag(FLAGS_output_score_tensor_name); + } + if (!absl::GetFlag(FLAGS_output_label_tensor_name).empty()) { + options.output_label_tensor_name = + absl::GetFlag(FLAGS_output_label_tensor_name); + } + + ASSIGN_OR_RETURN(std::unique_ptr<NLClassifier> classifier, + NLClassifier::CreateFromFileAndOptions( + absl::GetFlag(FLAGS_model_path), options)); + + std::vector<core::Category> categories = + classifier->Classify(absl::GetFlag(FLAGS_text)); + + for (int i = 0; i < categories.size(); ++i) { + const core::Category& category = categories[i]; + std::cout << absl::StrFormat("category[%d]: '%s' : '%.5f'\n", i, + category.class_name, category.score); + } + + return absl::OkStatus(); +} + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_text).empty()) { + std::cerr << "Missing mandatory 'question' argument.\n"; + return 1; + } + + // Run classification. 
+ absl::Status status = tflite::task::text::nlclassifier::Classify(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Classification failed: " << status.message() << "\n"; + return 1; + } +} diff --git a/tensorflow_lite_support/examples/task/vision/desktop/BUILD b/tensorflow_lite_support/examples/task/vision/desktop/BUILD new file mode 100644 index 00000000..f61984ee --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/BUILD @@ -0,0 +1,68 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:users", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_binary( + name = "image_classifier_demo", + srcs = ["image_classifier_demo.cc"], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/cc/task/vision:image_classifier", + "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_binary( + name = "object_detector_demo", + srcs = ["object_detector_demo.cc"], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/cc/task/vision:object_detector", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc", + 
"//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_binary( + name = "image_segmenter_demo", + srcs = ["image_segmenter_demo.cc"], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/cc/task/vision:image_segmenter", + "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) diff --git a/tensorflow_lite_support/examples/task/vision/desktop/README.md b/tensorflow_lite_support/examples/task/vision/desktop/README.md new file mode 100644 index 00000000..73c6b637 --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/README.md @@ -0,0 +1,180 @@ +# CLI Demos for C++ Vision Task APIs + +This folder contains simple command-line tools for easily trying out the C++ +Vision Task APIs. + +## Image Classifier + +#### Prerequisites + +You will need: + +* a TFLite image classification model (e.g. 
[aiy/vision/classifier/birds_V1][1], +a bird classification model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run classification on, e.g.: + +![sparrow](g3doc/sparrow.jpg) + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3?lite-format=tflite' \ + -o /tmp/aiy_vision_classifier_birds_V1_3.tflite + +# Run the classification tool: +bazel run -c opt \ + tensorflow_lite_support/examples/task/vision/desktop:image_classifier_demo -- \ + --model_path=/tmp/aiy_vision_classifier_birds_V1_3.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg \ + --max_results=3 +``` + +#### Results + +In the console, you should get: + +``` +Results: + Rank #0: + index : 671 + score : 0.91406 + class name : /m/01bwb9 + display name: Passer domesticus + Rank #1: + index : 670 + score : 0.00391 + class name : /m/01bwbt + display name: Passer montanus + Rank #2: + index : 495 + score : 0.00391 + class name : /m/0bwm6m + display name: Passer italiae +``` + +## Object Detector + +#### Prerequisites + +You will need: + +* a TFLite object detection model (e.g. 
[ssd_mobilenet_v1][2], a generic object +detection model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run detection on, e.g.: + +![dogs](g3doc/dogs.jpg) + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite' \ + -o /tmp/ssd_mobilenet_v1_1_metadata_1.tflite + +# Run the detection tool: +bazel run -c opt \ + tensorflow_lite_support/examples/task/vision/desktop:object_detector_demo -- \ + --model_path=/tmp/ssd_mobilenet_v1_1_metadata_1.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg \ + --output_png=/tmp/detection-output.png \ + --max_results=2 +``` + +#### Results + +In the console, you should get: + +``` +Results saved to: /tmp/detection-output.png +Results: + Detection #0 (red): + Box: (x: 355, y: 133, w: 190, h: 206) + Top-1 class: + index : 17 + score : 0.73828 + class name : dog + Detection #1 (green): + Box: (x: 103, y: 15, w: 138, h: 369) + Top-1 class: + index : 17 + score : 0.73047 + class name : dog +``` + +And `/tmp/detection-output.jpg` should contain: + +![detection-output](g3doc/detection-output.png) + +## Image Segmenter + +#### Prerequisites + +You will need: + +* a TFLite image segmentation model (e.g. 
[deeplab_v3][3], a generic +segmentation model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run segmentation on, e.g.: + +![plane](g3doc/plane.jpg) + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L 'https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1?lite-format=tflite' \ + -o /tmp/deeplabv3_1_metadata_1.tflite + +# Run the segmentation tool: +bazel run -c opt \ + tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo -- \ + --model_path=/tmp/deeplabv3_1_metadata_1.tflite \ + --image_path=\ +$(pwd)/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg \ + --output_mask_png=/tmp/segmentation-output.png +``` + +#### Results + +In the console, you should get: + +``` +Category mask saved to: /tmp/segmentation-output.png +Color Legend: + (r: 000, g: 000, b: 000): + index : 0 + class name : background + (r: 128, g: 000, b: 000): + index : 1 + class name : aeroplane + +# (omitting multiple lines for conciseness) ... + + (r: 128, g: 192, b: 000): + index : 19 + class name : train + (r: 000, g: 064, b: 128): + index : 20 + class name : tv +Tip: use a color picker on the output PNG file to inspect the output mask with +this legend. 
+And `/tmp/segmentation-output.png` should contain the segmentation mask:
a/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc new file mode 100644 index 00000000..cd97f4a2 --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc @@ -0,0 +1,173 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/vision/desktop:image_classifier_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --image_path=/path/to/image.jpg + +#include <iostream> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/image_classifier.h" +#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include 
"tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +ABSL_FLAG(std::string, model_path, "", + "Absolute path to the '.tflite' image classifier model."); +ABSL_FLAG(std::string, image_path, "", + "Absolute path to the image to classify. The image must be RGB or " + "RGBA (grayscale is not supported). The image EXIF orientation " + "flag, if any, is NOT taken into account."); +ABSL_FLAG(int32, max_results, 5, + "Maximum number of classification results to display."); +ABSL_FLAG(float, score_threshold, 0, + "Classification results with a confidence score below this value are " + "rejected. If >= 0, overrides the score threshold(s) provided in the " + "TFLite Model Metadata. Ignored otherwise."); +ABSL_FLAG( + std::vector<std::string>, class_name_whitelist, {}, + "Comma-separated list of class names that acts as a whitelist. If " + "non-empty, classification results whose 'class_name' is not in this list " + "are filtered out. Mutually exclusive with 'class_name_blacklist'."); +ABSL_FLAG( + std::vector<std::string>, class_name_blacklist, {}, + "Comma-separated list of class names that acts as a blacklist. If " + "non-empty, classification results whose 'class_name' is in this list " + "are filtered out. 
Mutually exclusive with 'class_name_whitelist'."); + +namespace tflite { +namespace task { +namespace vision { + +ImageClassifierOptions BuildOptions() { + ImageClassifierOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + options.set_max_results(absl::GetFlag(FLAGS_max_results)); + if (absl::GetFlag(FLAGS_score_threshold) >= 0) { + options.set_score_threshold(absl::GetFlag(FLAGS_score_threshold)); + } + for (const std::string& class_name : + absl::GetFlag(FLAGS_class_name_whitelist)) { + options.add_class_name_whitelist(class_name); + } + for (const std::string& class_name : + absl::GetFlag(FLAGS_class_name_blacklist)) { + options.add_class_name_blacklist(class_name); + } + return options; +} + +void DisplayResult(const ClassificationResult& result) { + std::cout << "Results:\n"; + for (int head = 0; head < result.classifications_size(); ++head) { + if (result.classifications_size() > 1) { + std::cout << absl::StrFormat(" Head index %d:\n", head); + } + const Classifications& classifications = result.classifications(head); + for (int rank = 0; rank < classifications.classes_size(); ++rank) { + const Class& classification = classifications.classes(rank); + std::cout << absl::StrFormat(" Rank #%d:\n", rank); + std::cout << absl::StrFormat(" index : %d\n", + classification.index()); + std::cout << absl::StrFormat(" score : %.5f\n", + classification.score()); + if (classification.has_class_name()) { + std::cout << absl::StrFormat(" class name : %s\n", + classification.class_name()); + } + if (classification.has_display_name()) { + std::cout << absl::StrFormat(" display name: %s\n", + classification.display_name()); + } + } + } +} + +absl::Status Classify() { + // Build ImageClassifier. + const ImageClassifierOptions& options = BuildOptions(); + ASSIGN_OR_RETURN(std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + // Load image in a FrameBuffer. 
+ ASSIGN_OR_RETURN(ImageData image, + DecodeImageFromFile(absl::GetFlag(FLAGS_image_path))); + std::unique_ptr<FrameBuffer> frame_buffer; + if (image.channels == 3) { + frame_buffer = + CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height}); + } else if (image.channels == 4) { + frame_buffer = + CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height}); + } else { + return absl::InvalidArgumentError(absl::StrFormat( + "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d", + image.channels)); + } + + // Run classification and display results. + ASSIGN_OR_RETURN(ClassificationResult result, + image_classifier->Classify(*frame_buffer)); + DisplayResult(result); + + // Cleanup and return. + ImageDataFree(&image); + return absl::OkStatus(); +} + +} // namespace vision +} // namespace task +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_image_path).empty()) { + std::cerr << "Missing mandatory 'image_path' argument.\n"; + return 1; + } + if (!absl::GetFlag(FLAGS_class_name_whitelist).empty() && + !absl::GetFlag(FLAGS_class_name_blacklist).empty()) { + std::cerr << "'class_name_whitelist' and 'class_name_blacklist' arguments " + "are mutually exclusive.\n"; + return 1; + } + + // Run classification. 
+ absl::Status status = tflite::task::vision::Classify(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Classification failed: " << status.message() << "\n"; + return 1; + } +} diff --git a/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc new file mode 100644 index 00000000..02af824f --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc @@ -0,0 +1,201 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --image_path=/path/to/image.jpg \ +// --output_mask_png=/path/to/output/mask.png + +#include <iostream> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +ABSL_FLAG(std::string, model_path, "", + "Absolute path to the '.tflite' image segmenter model."); +ABSL_FLAG(std::string, image_path, "", + "Absolute path to the image to segment. The image must be RGB or " + "RGBA (grayscale is not supported). The image EXIF orientation " + "flag, if any, is NOT taken into account."); +ABSL_FLAG(std::string, output_mask_png, "", + "Absolute path to the output category mask (confidence masks outputs " + "are not supported by this tool). 
Must have a '.png' extension."); + +namespace tflite { +namespace task { +namespace vision { + +ImageSegmenterOptions BuildOptions() { + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + // Confidence masks are not supported by this tool: output_type is set to + // CATEGORY_MASK by default. + return options; +} + +absl::Status EncodeMaskToPngFile(const SegmentationResult& result) { + if (result.segmentation_size() != 1) { + return absl::UnimplementedError( + "Image segmentation models with multiple output segmentations are not " + "supported by this tool."); + } + const Segmentation& segmentation = result.segmentation(0); + // Extract raw mask data as a uint8 pointer. + const uint8* raw_mask = + reinterpret_cast<const uint8*>(segmentation.category_mask().data()); + + // Create RgbImageData for the output mask. + uint8* pixel_data = static_cast<uint8*>( + malloc(segmentation.width() * segmentation.height() * 3 * sizeof(uint8))); + ImageData mask = {.pixel_data = pixel_data, + .width = segmentation.width(), + .height = segmentation.height(), + .channels = 3}; + + // Populate RgbImageData from the raw mask and ColoredLabel-s. + for (int i = 0; i < segmentation.width() * segmentation.height(); ++i) { + Segmentation::ColoredLabel colored_label = + segmentation.colored_labels(raw_mask[i]); + pixel_data[3 * i] = colored_label.r(); + pixel_data[3 * i + 1] = colored_label.g(); + pixel_data[3 * i + 2] = colored_label.b(); + } + + // Encode mask as PNG. + RETURN_IF_ERROR( + EncodeImageToPngFile(mask, absl::GetFlag(FLAGS_output_mask_png))); + std::cout << absl::StrFormat("Category mask saved to: %s\n", + absl::GetFlag(FLAGS_output_mask_png)); + + // Cleanup and return. 
+ ImageDataFree(&mask); + return absl::OkStatus(); +} + +absl::Status DisplayColorLegend(const SegmentationResult& result) { + if (result.segmentation_size() != 1) { + return absl::UnimplementedError( + "Image segmentation models with multiple output segmentations are not " + "supported by this tool."); + } + const Segmentation& segmentation = result.segmentation(0); + const int num_labels = segmentation.colored_labels_size(); + + std::cout << "Color Legend:\n"; + for (int index = 0; index < num_labels; ++index) { + Segmentation::ColoredLabel colored_label = + segmentation.colored_labels(index); + std::cout << absl::StrFormat(" (r: %03d, g: %03d, b: %03d):\n", + colored_label.r(), colored_label.g(), + colored_label.b()); + std::cout << absl::StrFormat(" index : %d\n", index); + if (colored_label.has_class_name()) { + std::cout << absl::StrFormat(" class name : %s\n", + colored_label.class_name()); + } + if (colored_label.has_display_name()) { + std::cout << absl::StrFormat(" display name: %s\n", + colored_label.display_name()); + } + } + std::cout << "Tip: use a color picker on the output PNG file to inspect the " + "output mask with this legend.\n"; + + return absl::OkStatus(); +} + +absl::Status Segment() { + // Build ImageClassifier. + const ImageSegmenterOptions& options = BuildOptions(); + ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> image_segmenter, + ImageSegmenter::CreateFromOptions(options)); + + // Load image in a FrameBuffer. 
+ ASSIGN_OR_RETURN(ImageData image, + DecodeImageFromFile(absl::GetFlag(FLAGS_image_path))); + std::unique_ptr<FrameBuffer> frame_buffer; + if (image.channels == 3) { + frame_buffer = + CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height}); + } else if (image.channels == 4) { + frame_buffer = + CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height}); + } else { + return absl::InvalidArgumentError(absl::StrFormat( + "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d", + image.channels)); + } + + // Run segmentation and save category mask. + ASSIGN_OR_RETURN(SegmentationResult result, + image_segmenter->Segment(*frame_buffer)); + RETURN_IF_ERROR(EncodeMaskToPngFile(result)); + + // Display the legend. + RETURN_IF_ERROR(DisplayColorLegend(result)); + + // Cleanup and return. + ImageDataFree(&image); + return absl::OkStatus(); +} + +} // namespace vision +} // namespace task +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_image_path).empty()) { + std::cerr << "Missing mandatory 'image_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_output_mask_png).empty()) { + std::cerr << "Missing mandatory 'output_mask_png' argument.\n"; + return 1; + } + if (!absl::EndsWithIgnoreCase(absl::GetFlag(FLAGS_output_mask_png), ".png")) { + std::cerr << "Argument 'output_mask_png' must end with '.png' or '.PNG'\n"; + return 1; + } + + // Run segmentation. 
+ absl::Status status = tflite::task::vision::Segment(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Segmentation failed: " << status.message() << "\n"; + return 1; + } +} diff --git a/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc new file mode 100644 index 00000000..b7ab651a --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc @@ -0,0 +1,251 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/vision/desktop:object_detector_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --image_path=/path/to/image.jpg \ +// --output_png=/path/to/output.png + +#include <iostream> +#include <limits> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/object_detector.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +ABSL_FLAG(std::string, model_path, "", + "Absolute path to the '.tflite' object detector model."); +ABSL_FLAG(std::string, image_path, "", + "Absolute path to the image to run detection on. The image must be " + "RGB or RGBA (grayscale is not supported). The image EXIF " + "orientation flag, if any, is NOT taken into account."); +ABSL_FLAG(std::string, output_png, "", + "Absolute path to a file where to draw the detection results on top " + "of the input image. 
Must have a '.png' extension."); +ABSL_FLAG(int32, max_results, 5, + "Maximum number of detection results to display."); +ABSL_FLAG( + float, score_threshold, std::numeric_limits<float>::lowest(), + "Detection results with a confidence score below this value are " + "rejected. If specified, overrides the score threshold(s) provided in the " + "TFLite Model Metadata. Ignored otherwise."); +ABSL_FLAG( + std::vector<std::string>, class_name_whitelist, {}, + "Comma-separated list of class names that acts as a whitelist. If " + "non-empty, detections results whose 'class_name' is not in this list " + "are filtered out. Mutually exclusive with 'class_name_blacklist'."); +ABSL_FLAG(std::vector<std::string>, class_name_blacklist, {}, + "Comma-separated list of class names that acts as a blacklist. If " + "non-empty, detections results whose 'class_name' is in this list " + "are filtered out. Mutually exclusive with 'class_name_whitelist'."); + +namespace tflite { +namespace task { +namespace vision { + +namespace { +// The line thickness (in pixels) for drawing the detection results. +constexpr int kLineThickness = 3; + +// The number of colors used for drawing the detection results. +constexpr int kColorMapSize = 10; + +// The names of the colors used for drawing the detection results. +constexpr std::array<absl::string_view, 10> kColorMapNames = { + "red", "green", "blue", "yellow", "fuschia", + "dark red", "dark green", "dark blue", "gray", "black"}; + +// The colors used for drawing the detection results as a flattened array of +// {R,G,B} components. 
+constexpr uint8 kColorMapComponents[30] = { + 255, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 0, 255, 0, 255, + 128, 0, 0, 0, 128, 0, 0, 0, 128, 128, 128, 128, 0, 0, 0}; +} // namespace + +ObjectDetectorOptions BuildOptions() { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + options.set_max_results(absl::GetFlag(FLAGS_max_results)); + if (absl::GetFlag(FLAGS_score_threshold) > + std::numeric_limits<float>::lowest()) { + options.set_score_threshold(absl::GetFlag(FLAGS_score_threshold)); + } + for (const std::string& class_name : + absl::GetFlag(FLAGS_class_name_whitelist)) { + options.add_class_name_whitelist(class_name); + } + for (const std::string& class_name : + absl::GetFlag(FLAGS_class_name_blacklist)) { + options.add_class_name_blacklist(class_name); + } + return options; +} + +absl::Status EncodeResultToPngFile(const DetectionResult& result, + const ImageData* image) { + for (int index = 0; index < result.detections_size(); ++index) { + // Get bounding box as left, top, right, bottom. + const BoundingBox& box = result.detections(index).bounding_box(); + const int left = box.origin_x(); + const int top = box.origin_y(); + const int right = box.origin_x() + box.width(); + const int bottom = box.origin_y() + box.height(); + // Get color components. + const uint8 r = kColorMapComponents[3 * (index % kColorMapSize)]; + const uint8 g = kColorMapComponents[3 * (index % kColorMapSize) + 1]; + const uint8 b = kColorMapComponents[3 * (index % kColorMapSize) + 2]; + // Draw. Boxes might have coordinates outside of [0, w( x [0, h( so clamping + // is applied. 
+ for (int y = std::max(0, top); y < std::min(image->height, bottom); ++y) { + for (int x = std::max(0, left); x < std::min(image->width, right); ++x) { + int pixel_index = image->channels * (image->width * y + x); + if (x < left + kLineThickness || x > right - kLineThickness || + y < top + kLineThickness || y > bottom - kLineThickness) { + image->pixel_data[pixel_index] = r; + image->pixel_data[pixel_index + 1] = g; + image->pixel_data[pixel_index + 2] = b; + } + } + } + } + // Encode to PNG and return. + RETURN_IF_ERROR( + EncodeImageToPngFile(*image, absl::GetFlag(FLAGS_output_png))); + std::cout << absl::StrFormat("Results saved to: %s\n", + absl::GetFlag(FLAGS_output_png)); + return absl::OkStatus(); +} + +void DisplayResult(const DetectionResult& result) { + std::cout << "Results:\n"; + for (int index = 0; index < result.detections_size(); ++index) { + std::cout << absl::StrFormat(" Detection #%d (%s):\n", index, + kColorMapNames[index % kColorMapSize]); + const Detection& detection = result.detections(index); + const BoundingBox& box = detection.bounding_box(); + std::cout << absl::StrFormat(" Box: (x: %d, y: %d, w: %d, h: %d)\n", + box.origin_x(), box.origin_y(), box.width(), + box.height()); + if (detection.classes_size() == 0) { + std::cout << " No top-1 class available"; + } else { + std::cout << " Top-1 class:\n"; + const Class& classification = detection.classes(0); + std::cout << absl::StrFormat(" index : %d\n", + classification.index()); + std::cout << absl::StrFormat(" score : %.5f\n", + classification.score()); + if (classification.has_class_name()) { + std::cout << absl::StrFormat(" class name : %s\n", + classification.class_name()); + } + if (classification.has_display_name()) { + std::cout << absl::StrFormat(" display name: %s\n", + classification.display_name()); + } + } + } +} + +absl::Status Detect() { + // Build ObjectDetector. 
+ const ObjectDetectorOptions& options = BuildOptions(); + ASSIGN_OR_RETURN(std::unique_ptr<ObjectDetector> object_detector, + ObjectDetector::CreateFromOptions(options)); + + // Load image in a FrameBuffer. + ASSIGN_OR_RETURN(ImageData image, + DecodeImageFromFile(absl::GetFlag(FLAGS_image_path))); + std::unique_ptr<FrameBuffer> frame_buffer; + if (image.channels == 3) { + frame_buffer = + CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height}); + } else if (image.channels == 4) { + frame_buffer = + CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height}); + } else { + return absl::InvalidArgumentError(absl::StrFormat( + "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d", + image.channels)); + } + + // Run object detection and draw results on input image. + ASSIGN_OR_RETURN(DetectionResult result, + object_detector->Detect(*frame_buffer)); + RETURN_IF_ERROR(EncodeResultToPngFile(result, &image)); + + // Display results as text. + DisplayResult(result); + + // Cleanup and return. + ImageDataFree(&image); + return absl::OkStatus(); +} + +} // namespace vision +} // namespace task +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. 
+ absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_image_path).empty()) { + std::cerr << "Missing mandatory 'image_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_output_png).empty()) { + std::cerr << "Missing mandatory 'output_png' argument.\n"; + return 1; + } + if (!absl::EndsWithIgnoreCase(absl::GetFlag(FLAGS_output_png), ".png")) { + std::cerr << "Argument 'output_png' must end with '.png' or '.PNG'\n"; + return 1; + } + if (!absl::GetFlag(FLAGS_class_name_whitelist).empty() && + !absl::GetFlag(FLAGS_class_name_blacklist).empty()) { + std::cerr << "'class_name_whitelist' and 'class_name_blacklist' arguments " + "are mutually exclusive.\n"; + return 1; + } + + // Run detection. + absl::Status status = tflite::task::vision::Detect(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Detection failed: " << status.message() << "\n"; + return 1; + } +} diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD b/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD new file mode 100644 index 00000000..9a837e2d --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD @@ -0,0 +1,22 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:users", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "image_utils", + srcs = ["image_utils.cc"], + hdrs = ["image_utils.h"], + deps = [ + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@stblib//:stb_image", + "@stblib//:stb_image_write", + ], +) diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc 
b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc new file mode 100644 index 00000000..7c4604e9 --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc @@ -0,0 +1,94 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +#include <cstdlib> +#include <cstring> +#include <vector> + +// These need to be defined for stb_image.h and stb_image_write.h to include +// the actual implementations of image decoding/encoding functions. 
+#define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_WRITE_IMPLEMENTATION + +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "stb_image.h" +#include "stb_image_write.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/port/statusor.h" + +namespace tflite { +namespace task { +namespace vision { + +using ::tflite::support::StatusOr; + +StatusOr<ImageData> DecodeImageFromFile(const std::string& file_name) { + ImageData image_data; + image_data.pixel_data = stbi_load(file_name.c_str(), &image_data.width, + &image_data.height, &image_data.channels, + /*desired_channels=*/0); + if (image_data.pixel_data == nullptr) { + return absl::InternalError(absl::StrFormat( + "An error occurred while decoding image: %s", stbi_failure_reason())); + } + if (image_data.channels != 1 && image_data.channels != 3 && + image_data.channels != 4) { + stbi_image_free(image_data.pixel_data); + return absl::UnimplementedError( + absl::StrFormat("Expected image with 1 (grayscale), 3 (RGB) or 4 " + "(RGBA) channels, found %d", + image_data.channels)); + } + return image_data; +} + +absl::Status EncodeImageToPngFile(const ImageData& image_data, + const std::string& image_path) { + // Sanity check inputs. 
+ if (image_data.width <= 0 || image_data.height <= 0) { + return absl::InvalidArgumentError( + absl::StrFormat("Expected positive image dimensions, found %d x %d.", + image_data.width, image_data.height)); + } + if (image_data.channels != 1 && image_data.channels != 3 && + image_data.channels != 4) { + return absl::UnimplementedError( + absl::StrFormat("Expected image data with 1 (grayscale), 3 (RGB) or 4 " + "(RGBA) channels, found %d", + image_data.channels)); + } + if (image_data.pixel_data == nullptr) { + return absl::InvalidArgumentError( + "Expected pixel data to be set, found nullptr."); + } + + if (stbi_write_png( + image_path.c_str(), image_data.width, image_data.height, + image_data.channels, image_data.pixel_data, + /*stride_in_bytes=*/image_data.width * image_data.channels) == 0) { + return absl::InternalError("An error occurred while encoding image."); + } + + return absl::OkStatus(); +} + +void ImageDataFree(ImageData* image) { stbi_image_free(image->pixel_data); } + +} // namespace vision +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h new file mode 100644 index 00000000..38f62b60 --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/statusor.h" + +namespace tflite { +namespace task { +namespace vision { + +// Image data with pixels stored as a row-major flattened array. +// Channels can be: +// 1 : grayscale +// 3 : RGB, interleaved +// 4 : RGBA, interleaved +struct ImageData { + uint8* pixel_data; + int width; + int height; + int channels; +}; + +// Decodes image file and returns the corresponding image if no error +// occurred. If decoding succeeded, the caller must manage deletion of the +// underlying pixel data using `ImageDataFree`. +// Supports a wide range of image formats, listed in `stb_image/stb_image.h`. +tflite::support::StatusOr<ImageData> DecodeImageFromFile( + const std::string& file_name); + +// Encodes the image provided as an ImageData as lossless PNG to the provided +// path. +absl::Status EncodeImageToPngFile(const ImageData& image_data, + const std::string& image_path); + +// Releases image pixel data memory. 
+void ImageDataFree(ImageData* image); + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ diff --git a/tensorflow_lite_support/ios/BUILD b/tensorflow_lite_support/ios/BUILD new file mode 100644 index 00000000..07b89515 --- /dev/null +++ b/tensorflow_lite_support/ios/BUILD @@ -0,0 +1,48 @@ +# TensorFlow Lite Task Library - Text + +load( + "@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", + "TFL_MINIMUM_OS_VERSION", + "tflite_ios_static_framework", +) +load( + "//tensorflow_lite_support/ios:ios.bzl", + "strip_c_api_include_path_prefix", +) + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +strip_c_api_include_path_prefix( + name = "strip_c_api_include_path", + hdr_labels = [ + "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier_c_api.h", + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api.h", + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api_common.h", + "//tensorflow_lite_support/cc/task/text/qa:bert_qa_c_api.h", + ], +) + +# This target builds a monolithic static framework for the TFLite Text API, +# which includes the TFLite runtime in it. 
+# +# bazel build -c opt --config=ios_fat //tensorflow_lite_support/ios:TensorFlowLiteTaskTextC_framework +tflite_ios_static_framework( + name = "TensorFlowLiteTaskTextC_framework", + hdrs = [ + ":bert_nl_classifier_c_api.h", + ":bert_qa_c_api.h", + ":nl_classifier_c_api.h", + ":nl_classifier_c_api_common.h", + ], + allowlist_symbols_file = ":allowlist_TensorFlowLiteTaskText.txt", + bundle_name = "TensorFlowLiteTaskTextC", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + deps = [ + "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier_c_api", + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api", + "//tensorflow_lite_support/cc/task/text/qa:bert_qa_c_api", + ], +) diff --git a/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template b/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template new file mode 100644 index 00000000..62c3f339 --- /dev/null +++ b/tensorflow_lite_support/ios/TensorFlowLiteTaskText.podspec.template @@ -0,0 +1,44 @@ +Pod::Spec.new do |s| + s.name = 'TensorFlowLiteTaskText' + s.version = '${TFLS_BUILD_VERSION}' + s.authors = 'Google Inc.' 
+ s.license = { :type => 'Apache' } + s.homepage = 'https://github.com/tensorflow/tflite-support' + s.source = { :http => '${TFLS_DOWNLOAD_URL}' } + s.summary = 'TensorFlow Lite Task Library - Text' + s.description = 'The Natural Language APIs of the TFLite Task Library' + + s.ios.deployment_target = '9.0' + + s.module_name = 'TensorFlowLiteTaskText' + s.static_framework = true + + s.dependency 'GoogleToolboxForMac', '2.2.1' + + objc_dir = 'tensorflow_lite_support/ios/task/text/' + s.public_header_files = [ + objc_dir + 'apis/*.h', + objc_dir + '{nlclassifier,qa}/Sources/*.h' + ] + + cc_dir = 'tensorflow_lite_support/cc/task/text/' + s.source_files = [ + cc_dir + '{nlclassifier,qa}/*_c_api*.h', + objc_dir + 'apis/*.h', + objc_dir + '{nlclassifier,qa}/Sources/*.{h,m,mm}' + ] + s.module_map = objc_dir + 'apis/framework.modulemap' + s.pod_target_xcconfig = { + 'HEADER_SEARCH_PATHS' => + '"${PODS_TARGET_SRCROOT}" ' + + '"${PODS_TARGET_SRCROOT}/' + cc_dir + 'nlclassifier" ' + + '"${PODS_TARGET_SRCROOT}/' + cc_dir + 'qa" ' + + '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'apis" ' + + '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'nlclassifier/Sources" ' + + '"${PODS_TARGET_SRCROOT}/' + objc_dir + 'qa/Sources"', + 'VALID_ARCHS' => 'x86_64 armv7 arm64', + } + + s.library = 'c++' + s.vendored_frameworks = 'Frameworks/TensorFlowLiteTaskTextC.framework' +end diff --git a/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt b/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt new file mode 100644 index 00000000..3af5b0b1 --- /dev/null +++ b/tensorflow_lite_support/ios/allowlist_TensorFlowLiteTaskText.txt @@ -0,0 +1,3 @@ +_NLClassifier* +_BertNLClassifier* +_BertQuestionAnswerer* diff --git a/tensorflow_lite_support/ios/ios.bzl b/tensorflow_lite_support/ios/ios.bzl new file mode 100644 index 00000000..cb8c92ac --- /dev/null +++ b/tensorflow_lite_support/ios/ios.bzl @@ -0,0 +1,30 @@ +"""TensorFlow Lite Support Library Helper Rules for iOS""" + +# When the static 
framework is built with bazel, the all header files are moved +# to the "Headers" directory with no header path prefixes. This auxiliary rule +# is used for stripping the path prefix to the C API header files included by +# other C API header files. +def strip_c_api_include_path_prefix(name, hdr_labels, prefix = ""): + """Create modified header files with the common.h include path stripped out. + + Args: + name: The name to be used as a prefix to the generated genrules. + hdr_labels: List of header labels to strip out the include path. Each + label must end with a colon followed by the header file name. + prefix: Optional prefix path to prepend to the header inclusion path. + """ + + for hdr_label in hdr_labels: + hdr_filename = hdr_label.split(":")[-1] + hdr_basename = hdr_filename.split(".")[0] + + native.genrule( + name = "{}_{}".format(name, hdr_basename), + srcs = [hdr_label], + outs = [hdr_filename], + cmd = """ + sed 's|#include ".*/\\([^/]\\{{1,\\}}\\.h\\)"|#include "{}\\1"|'\ + "$(location {})"\ + > "$@" + """.format(prefix, hdr_label), + ) diff --git a/tensorflow_lite_support/ios/task/text/apis/TFLTaskText.h b/tensorflow_lite_support/ios/task/text/apis/TFLTaskText.h new file mode 100644 index 00000000..a42a4b38 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/apis/TFLTaskText.h @@ -0,0 +1,17 @@ +// Copyright 2020 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import "TFLBertNLClassifier.h" +#import "TFLBertQuestionAnswerer.h" +#import "TFLNLClassifier.h" diff --git a/tensorflow_lite_support/ios/task/text/apis/framework.modulemap b/tensorflow_lite_support/ios/task/text/apis/framework.modulemap new file mode 100644 index 00000000..3e267620 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/apis/framework.modulemap @@ -0,0 +1,4 @@ +framework module TensorFlowLiteTaskText { + umbrella header "TFLTaskText.h" + export * +} diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD b/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD new file mode 100644 index 00000000..fb369e90 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/nlclassifier/BUILD @@ -0,0 +1,125 @@ +load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") +load("@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") +load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLBertNLClassifier", + srcs = ["Sources/TFLBertNLClassifier.m"], + hdrs = ["Sources/TFLBertNLClassifier.h"], + module_name = "TFLBertNLClassifier", + deps = [ + "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier_c_api", + "@google_toolbox_for_mac//:GTM_Defines", + ], +) + +swift_library( + name = "TFLBertNLClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["Tests/TFLBertNLClassifierTest.swift"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + ":TFLBertNLClassifier", + "//third_party/swift/xctest", + ], +) + +ios_unit_test( + name = "TFLBertNLClassifierSwiftTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = 
tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLBertNLClassifierSwiftTestLibrary", + ], +) + +objc_library( + name = "TFLBertNLClassifierObjcTestLibrary", + testonly = 1, + srcs = ["Tests/TFLBertNLClassifierTest.m"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + ":TFLBertNLClassifier", + ], +) + +ios_unit_test( + name = "TFLBertNLClassifierObjcTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLBertNLClassifierObjcTestLibrary", + ], +) + +objc_library( + name = "TFLNLClassifier", + srcs = ["Sources/TFLNLClassifier.m"], + hdrs = ["Sources/TFLNLClassifier.h"], + module_name = "TFLNLClassifier", + deps = [ + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier_c_api", + "@google_toolbox_for_mac//:GTM_Defines", + ], +) + +swift_library( + name = "TFLNLClassifierSwiftTestLibrary", + testonly = 1, + srcs = ["Tests/TFLNLClassifierTest.swift"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + ":TFLNLClassifier", + "//third_party/swift/xctest", + ], +) + +ios_unit_test( + name = "TFLNLClassifierSwiftTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLNLClassifierSwiftTestLibrary", + ], +) + +objc_library( + name = "TFLNLClassifierObjcTestLibrary", + testonly = 1, + srcs = ["Tests/TFLNLClassifierTest.m"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:nl_classifier_models", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + ":TFLNLClassifier", + ], +) + +ios_unit_test( + name = "TFLNLClassifierObjcTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = 
tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLNLClassifierObjcTestLibrary", + ], +) diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h new file mode 100644 index 00000000..ceed6fa8 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN + +/** + * Classifier API for NLClassification tasks with Bert models, categorizes string into different + * classes. The API expects a Bert based TFLite model with metadata populated. + * + * The metadata should contain the following information: + * 1 input_process_unit for Wordpiece/Sentencepiece Tokenizer. + * 3 input tensors with names "ids", "mask" and "segment_ids". + * 1 output tensor of type float32[1, 2], with a optionally attached label file. If a label + * file is attached, the file should be a plain text file with one label per line, the number + * of labels should match the number of categories the model outputs. + */ +@interface TFLBertNLClassifier : NSObject + +/** + * Creates TFLBertNLClassifier from a model file. 
+ * + * @param modelPath Path to the classification model. + * @return A TFLBertNLClassifier instance. + */ ++ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath + NS_SWIFT_NAME(bertNLClassifier(modelPath:)); + +/** + * Performs classification on a NSString input, returns <NSString *, NSNumber *> + * for categories and scores. + * + * @param text input text to the model. + * @return A NSDictionary of categorization results. + */ +- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text + NS_SWIFT_NAME(classify(text:)); +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m new file mode 100644 index 00000000..24c37b78 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.m @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h" +#import "GTMDefines.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier_c_api.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface TFLBertNLClassifier () +/** BertNLClassifier backed by C API */ +@property(nonatomic) BertNLClassifier *bertNLClassifier; +@end + +@implementation TFLBertNLClassifier + +- (void)dealloc { + BertNLClassifierDelete(_bertNLClassifier); +} + ++ (instancetype)bertNLClassifierWithModelPath:(NSString *)modelPath { + BertNLClassifier *classifier = BertNLClassifierFromFile(modelPath.UTF8String); + + _GTMDevAssert(classifier, @"Failed to create BertNLClassifier"); + return [[TFLBertNLClassifier alloc] initWithBertNLClassifier:classifier]; +} + +- (instancetype)initWithBertNLClassifier:(BertNLClassifier *)bertNLClassifier { + self = [super init]; + if (self) { + _bertNLClassifier = bertNLClassifier; + } + return self; +} + +- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text { + struct Categories *cCategories = BertNLClassifierClassify(_bertNLClassifier, text.UTF8String); + NSMutableDictionary<NSString *, NSNumber *> *ret = [NSMutableDictionary dictionary]; + for (int i = 0; i < cCategories->size; i++) { + struct Category cCategory = cCategories->categories[i]; + [ret setValue:[NSNumber numberWithDouble:cCategory.score] + forKey:[NSString stringWithUTF8String:cCategory.text]]; + } + NLClassifierCategoriesDelete(cCategories); + return ret; +} +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h new file mode 100644 index 00000000..ceb8d2ef --- /dev/null +++ 
b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN + +/** + * Options to identify input and output tensors of the model. + */ +@interface TFLNLClassifierOptions : NSObject +@property(nonatomic) int inputTensorIndex; +@property(nonatomic) int outputScoreTensorIndex; +@property(nonatomic) int outputLabelTensorIndex; +@property(nonatomic) NSString *inputTensorName; +@property(nonatomic) NSString *outputScoreTensorName; +@property(nonatomic) NSString *outputLabelTensorName; +@end + +/** + * Classifier API for natural language classification tasks, categorizes string into different + * classes. + * + * The API expects a TFLite model with the following input/output tensor: + * + * Input tensor (kTfLiteString) + * input of the model, accepts a string. + * + * Output score tensor + * (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool) + * output scores for each class, if type is one of the Int types, dequantize it, if it + * is Bool type, convert the values to 0.0 and 1.0 respectively. 
+ * + * can have an optional associated file in metadata for labels, the file should be a + * plain text file with one label per line, the number of labels should match the number + * of categories the model outputs. Output label tensor: optional (kTfLiteString) - + * output classname for each class, should be of the same length with scores. If this + * tensor is not present, the API uses score indices as classnames. - will be ignored if + * output score tensor already has an associated label file. + * + * Optional Output label tensor (kTfLiteString/kTfLiteInt32) + * output classname for each class, should be of the same length with scores. If this + * tensor is not present, the API uses score indices as classnames. + * + * will be ignored if output score tensor already has an associated label file. + * + * By default the API tries to find the input/output tensors with default configurations in + * TFLNLClassifierOptions, with tensor name prioritized over tensor index. The option is + * configurable for different TFLite models. + */ +@interface TFLNLClassifier : NSObject + +/** + * Creates a TFLNLClassifier instance from TFLNLClassifierOptions. + * + * @param modelPath The file path to the tflite model. + * @param options The TFLNLClassifierOptions to configure the model. + * + * @return A TFLNLClassifier instance. + */ ++ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath + options:(TFLNLClassifierOptions *)options + NS_SWIFT_NAME(nlClassifier(modelPath:options:)); + +/** + * Performs classification on a NSString input, returns <NSString *, NSNumber *> + * for categories and scores. + * + * @param text input text to the model. + * @return A NSDictionary of categorization results. 
+ */ +- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text + NS_SWIFT_NAME(classify(text:)); +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m new file mode 100644 index 00000000..01a48188 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.m @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h" +#import "GTMDefines.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier_c_api_common.h" + +NS_ASSUME_NONNULL_BEGIN + +@implementation TFLNLClassifierOptions +@synthesize inputTensorIndex; +@synthesize outputScoreTensorIndex; +@synthesize outputLabelTensorIndex; +@synthesize inputTensorName; +@synthesize outputScoreTensorName; +@synthesize outputLabelTensorName; +@end + +@interface TFLNLClassifier () +/** NLClassifier backed by C API */ +@property(nonatomic) NLClassifier *nlClassifier; +@end + +@implementation TFLNLClassifier + +- (void)dealloc { + NLClassifierDelete(_nlClassifier); +} + ++ (instancetype)nlClassifierWithModelPath:(NSString *)modelPath + options:(TFLNLClassifierOptions *)options { + struct NLClassifierOptions cOptions = { + .input_tensor_index = options.inputTensorIndex, + .output_score_tensor_index = options.outputScoreTensorIndex, + .output_label_tensor_index = options.outputLabelTensorIndex, + .input_tensor_name = options.inputTensorName.UTF8String, + .output_score_tensor_name = + options.outputScoreTensorName.UTF8String, + .output_label_tensor_name = + options.outputLabelTensorName.UTF8String + }; + NLClassifier *classifier = NLClassifierFromFileAndOptions(modelPath.UTF8String, &cOptions); + _GTMDevAssert(classifier, @"Failed to create NLClassifier"); + return [[TFLNLClassifier alloc] initWithNLClassifier:classifier]; +} + +- (instancetype)initWithNLClassifier:(NLClassifier *)nlClassifier { + self = [super init]; + if (self) { + _nlClassifier = nlClassifier; + } + return self; +} + +- (NSDictionary<NSString *, NSNumber *> *)classifyWithText:(NSString *)text { + struct Categories *cCategories = NLClassifierClassify(_nlClassifier, text.UTF8String); + 
NSMutableDictionary<NSString *, NSNumber *> *ret = [NSMutableDictionary dictionary]; + for (int i = 0; i < cCategories->size; i++) { + struct Category cCategory = cCategories->categories[i]; + [ret setValue:[NSNumber numberWithDouble:cCategory.score] + forKey:[NSString stringWithUTF8String:cCategory.text]]; + } + NLClassifierCategoriesDelete(cCategories); + return ret; +} +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m new file mode 100644 index 00000000..9565bfb2 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.m @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h" + +#import <XCTest/XCTest.h> + +NS_ASSUME_NONNULL_BEGIN + +@interface TFLBertNLClassifierTest : XCTestCase +@property(nonatomic, nullable) NSString *bertModelPath; +@end + +@implementation TFLBertNLClassifierTest +#pragma mark - Tests + +- (void)setUp { + [super setUp]; + NSBundle *bundle = [NSBundle bundleForClass:[self class]]; + self.bertModelPath = [bundle pathForResource:@"test_model_nl_classifier_bert" + ofType:@"tflite"]; +} + +- (void)testClassifyPositiveResult { + TFLBertNLClassifier* bertNLClassifier = + [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath]; + + XCTAssertNotNil(bertNLClassifier); + + NSDictionary<NSString *, NSNumber *> * categories = + [bertNLClassifier classifyWithText:@"it's a charming and often affecting journey"]; + + XCTAssertGreaterThan([categories[@"positive"] doubleValue], + [categories[@"negative"] doubleValue]); +} + +- (void)testClassifyNegativeResult { + TFLBertNLClassifier* bertNLClassifier = + [TFLBertNLClassifier bertNLClassifierWithModelPath:self.bertModelPath]; + + XCTAssertNotNil(bertNLClassifier); + + NSDictionary<NSString *, NSNumber *> * categories = + [bertNLClassifier classifyWithText:@"unflinchingly bleak and desperate"]; + + XCTAssertGreaterThan([categories[@"negative"] doubleValue], + [categories[@"positive"] doubleValue]); +} +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift new file mode 100644 index 00000000..d331b04e --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLBertNLClassifierTest.swift @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import XCTest + +@testable import TFLBertNLClassifier + +class TFLBertNLClassifierTest: XCTestCase { + + static let bundle = Bundle(for: TFLBertNLClassifierTest.self) + static let bertModelPath = bundle.path(forResource: "test_model_nl_classifier_bert", ofType: "tflite")! + + func testClassifyPositiveResult() { + let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier( + modelPath: TFLBertNLClassifierTest.bertModelPath) + + XCTAssertNotNil(bertNLClassifier) + + let categories = bertNLClassifier.classify(text: "it's a charming and often affecting journey") + + XCTAssertGreaterThan(categories["positive"]!.doubleValue, categories["negative"]!.doubleValue) + } + + func testClassifyNegativeResult() { + let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier( + modelPath: TFLBertNLClassifierTest.bertModelPath) + + XCTAssertNotNil(bertNLClassifier) + + let categories = bertNLClassifier.classify(text: "unflinchingly bleak and desperate") + + XCTAssertGreaterThan(categories["negative"]!.doubleValue, categories["positive"]!.doubleValue) + } +} diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m new file mode 100644 index 00000000..40814ac6 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.m @@ 
-0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import "tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLNLClassifier.h" + +#import <XCTest/XCTest.h> + +NS_ASSUME_NONNULL_BEGIN + +@interface TFLNLClassifierTest : XCTestCase +@property(nonatomic, nullable) NSString *modelPath; +@property(nonatomic, nullable) TFLNLClassifierOptions *modelOptions; +@end + +@implementation TFLNLClassifierTest +#pragma mark - Tests + +- (void)setUp { + [super setUp]; + NSBundle *bundle = [NSBundle bundleForClass:[self class]]; + self.modelPath = [bundle pathForResource:@"test_model_nl_classifier_with_regex_tokenizer" + ofType:@"tflite"]; + self.modelOptions = [[TFLNLClassifierOptions alloc] init]; + [self.modelOptions setInputTensorName:@"input_text"]; + [self.modelOptions setOutputScoreTensorName:@"probability"]; +} + +- (void)testClassifyPositiveResult { + TFLNLClassifier *nlClassifier = [TFLNLClassifier nlClassifierWithModelPath:self.modelPath + options:self.modelOptions]; + + XCTAssertNotNil(nlClassifier); + + NSDictionary<NSString *, NSNumber *> *categories = [nlClassifier + classifyWithText:@"This is the best movie I’ve seen in recent years. 
Strongly recommend it!"]; + + XCTAssertGreaterThan([categories[@"Positive"] doubleValue], + [categories[@"Negative"] doubleValue]); +} + +- (void)testClassifyNegativeResult { + TFLNLClassifier *nlClassifier = [TFLNLClassifier nlClassifierWithModelPath:self.modelPath + options:self.modelOptions]; + + XCTAssertNotNil(nlClassifier); + + NSDictionary<NSString *, NSNumber *> *categories = + [nlClassifier classifyWithText:@"What a waste of my time."]; + + XCTAssertGreaterThan([categories[@"Negative"] doubleValue], + [categories[@"Positive"] doubleValue]); +} +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.swift b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.swift new file mode 100644 index 00000000..fb80e5da --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/nlclassifier/Tests/TFLNLClassifierTest.swift @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import XCTest + +@testable import TFLNLClassifier + +class TFLNLClassifierTest: XCTestCase { + + static let bundle = Bundle(for: TFLNLClassifierTest.self) + static let modelPath = bundle.path( + forResource: "test_model_nl_classifier_with_regex_tokenizer", + ofType: "tflite")! 
+ + var modelOptions:TFLNLClassifierOptions!; + + override func setUp() { + modelOptions = TFLNLClassifierOptions() + modelOptions.inputTensorName = "input_text" + modelOptions.outputScoreTensorName = "probability" + } + + func testClassifyPositiveResult() { + let nlClassifier = TFLNLClassifier.nlClassifier( + modelPath: TFLNLClassifierTest.modelPath, + options: modelOptions) + + XCTAssertNotNil(nlClassifier) + + let categories = nlClassifier.classify( + text: "This is the best movie I’ve seen in recent years. Strongly recommend it!") + + XCTAssertGreaterThan(categories["Positive"]!.doubleValue, categories["Negative"]!.doubleValue) + } + + func testClassifyNegativeResult() { + let nlClassifier = TFLNLClassifier.nlClassifier( + modelPath: TFLNLClassifierTest.modelPath, + options: modelOptions) + + XCTAssertNotNil(nlClassifier) + + let categories = nlClassifier.classify(text: "What a waste of my time.") + + XCTAssertGreaterThan(categories["Negative"]!.doubleValue, categories["Positive"]!.doubleValue) + } +} diff --git a/tensorflow_lite_support/ios/task/text/qa/BUILD b/tensorflow_lite_support/ios/task/text/qa/BUILD new file mode 100644 index 00000000..7998a8e9 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/qa/BUILD @@ -0,0 +1,71 @@ +load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") +load("@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") +load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLBertQuestionAnswerer", + srcs = ["Sources/TFLBertQuestionAnswerer.m"], + hdrs = ["Sources/TFLBertQuestionAnswerer.h"], + module_name = "TFLBertQuestionAnswerer", + deps = [ + "//tensorflow_lite_support/cc/task/text/qa:bert_qa_c_api", 
+ "@google_toolbox_for_mac//:GTM_Defines", + ], +) + +swift_library( + name = "TFLBertQuestionAnswererSwiftTestLibrary", + testonly = 1, + srcs = glob(["Tests/*.swift"]), + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:albert_model", + "//tensorflow_lite_support/cc/test/testdata/task/text:mobile_bert_model", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + ":TFLBertQuestionAnswerer", + "//third_party/swift/xctest", + ], +) + +ios_unit_test( + name = "TFLBertQuestionAnswererSwiftTest", + size = "large", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLBertQuestionAnswererSwiftTestLibrary", + ], +) + +objc_library( + name = "TFLBertQuestionAnswererObjcTestLibrary", + testonly = 1, + srcs = glob(["Tests/*.m"]), + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:albert_model", + "//tensorflow_lite_support/cc/test/testdata/task/text:mobile_bert_model", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + ":TFLBertQuestionAnswerer", + ], +) + +ios_unit_test( + name = "TFLBertQuestionAnswererObjcTest", + size = "large", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLBertQuestionAnswererObjcTestLibrary", + ], +) diff --git a/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h b/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h new file mode 100644 index 00000000..57b7c69c --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN +/** + * Struct to represent the logit and offset of the answer related to context. + */ +struct TFLPos { + int start; + int end; + float logit; +}; + +/** + * Class for the Answer to BertQuestionAnswerer. + */ +@interface TFLQAAnswer : NSObject +@property(nonatomic) struct TFLPos pos; +@property(nonatomic) NSString* text; +@end + +/** + * BertQA task API, performs tokenization for models (BERT, Albert, etc.) in + * preprocess and returns most possible answers. + * + * In particular, the branch of BERT models use WordPiece tokenizer, and the + * branch of Albert models use SentencePiece tokenizer, respectively. + */ +@interface TFLBertQuestionAnswerer : NSObject + +/** + * Creates a BertQuestionAnswerer instance with an albert model or mobilebert + * model. The API expects a Bert based TFLite model with metadata containing + * the following information: + * input_process_units: for Wordpiece/Sentencepiece Tokenizer + * 3 input tensors with names "ids", "mask" and "segment_ids" + * 2 output tensors with names "end_logits" and "start_logits" + * Sample models: + * https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 + * https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 + * @param modelPath The file path to the tflite model. + * @return A BertQuestionAnswerer instance. 
+ */ ++ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath + NS_SWIFT_NAME(questionAnswerer(modelPath:)); + +/** + * Answers question based on the context. Could be empty if no answer was found + * from the given context. + * + * @param context Context the question bases on. + * @param question Question to ask. + * + * @return A list of answers to the question, reversely sorted by the + * probability of each answer. + */ +- (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context + question:(NSString*)question + NS_SWIFT_NAME(answer(context:question:)); +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m b/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m new file mode 100644 index 00000000..fc1bd08b --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.m @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#import "tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h" +#import "GTMDefines.h" +#include "tensorflow_lite_support/cc/task/text/qa/bert_qa_c_api.h" + +NS_ASSUME_NONNULL_BEGIN + +@implementation TFLQAAnswer +@synthesize pos; +@synthesize text; +@end + +@interface TFLBertQuestionAnswerer() +/** BertQuestionAnswerer backed by C API */ +@property(nonatomic) BertQuestionAnswerer *bertQuestionAnswerer; +@end + +@implementation TFLBertQuestionAnswerer + +- (void)dealloc { + BertQuestionAnswererDelete(_bertQuestionAnswerer); +} + ++ (instancetype)questionAnswererWithModelPath:(NSString *)modelPath { + BertQuestionAnswerer* bert_qa = BertQuestionAnswererFromFile(modelPath.UTF8String); + + _GTMDevAssert(bert_qa, @"Failed to create BertQuestionAnswerer"); + return [[TFLBertQuestionAnswerer alloc] + initWithBertQuestionAnswerer:bert_qa]; +} + +- (instancetype)initWithBertQuestionAnswerer:(BertQuestionAnswerer *)bertQuestionAnswerer { + self = [super init]; + if (self) { + _bertQuestionAnswerer = bertQuestionAnswerer; + } + return self; +} + +- (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question { + struct QaAnswers *cAnswers = + BertQuestionAnswererAnswer(_bertQuestionAnswerer, context.UTF8String, question.UTF8String); + NSMutableArray<TFLQAAnswer *> *ret = [NSMutableArray arrayWithCapacity:cAnswers->size]; + for (int i = 0; i < cAnswers->size; i++) { + struct QaAnswer cAnswer = cAnswers->answers[i]; + TFLQAAnswer *answer = [[TFLQAAnswer alloc] init]; + struct TFLPos pos = {.start = cAnswer.start, + .end = cAnswer.end, + .logit = cAnswer.logit}; + [answer setPos:pos]; + [answer setText:[NSString stringWithUTF8String:cAnswer.text]]; + [ret addObject:answer]; + } + BertQuestionAnswererQaAnswersDelete(cAnswers); + return ret; +} +@end +NS_ASSUME_NONNULL_END diff --git 
a/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m b/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m new file mode 100644 index 00000000..90610630 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.m @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import "tensorflow_lite_support/ios/task/text/qa/Sources/TFLBertQuestionAnswerer.h" + +#import <XCTest/XCTest.h> + +static NSString *const kContext = + @"The role of teacher is often formal and ongoing, carried out at a school " + "or other place of formal education. In many countries, a person who " + "wishes to become a teacher must first obtain specified professional " + "qualifications or credentials from a university or college. These " + "professional qualifications may include the study of pedagogy, the " + "science of teaching. Teachers, like other professionals, may have to " + "continue their education after they qualify, a process known as " + "continuing professional development. 
Teachers may use a lesson plan to " + "facilitate student learning, providing a course of study which is called " + "the curriculum."; +static NSString *const kQuestion = @"What is a course of study called?"; +static NSString *const kAnswer = @"the curriculum."; + +@interface TFLBertQuestionAnswererTest : XCTestCase +@property(nonatomic, nullable) NSString *mobileBertModelPath; +@property(nonatomic, nullable) NSString *albertModelPath; +@end + +@implementation TFLBertQuestionAnswererTest +#pragma mark - Tests + +- (void)setUp { + [super setUp]; + NSBundle *bundle = [NSBundle bundleForClass:[self class]]; + self.mobileBertModelPath = [bundle pathForResource:@"mobilebert_with_metadata" ofType:@"tflite"]; + self.albertModelPath = [bundle pathForResource:@"albert_with_metadata" ofType:@"tflite"]; +} + +- (void)testInitMobileBert { + TFLBertQuestionAnswerer* mobileBertAnswerer = + [TFLBertQuestionAnswerer questionAnswererWithModelPath:self.mobileBertModelPath]; + + XCTAssertNotNil(mobileBertAnswerer); + + NSArray<TFLQAAnswer*>* answers = + [mobileBertAnswerer answerWithContext:kContext question:kQuestion]; + + XCTAssertEqualObjects([answers[0] text], kAnswer); +} + +- (void)testInitAlbert { + TFLBertQuestionAnswerer* albertAnswerer = + [TFLBertQuestionAnswerer questionAnswererWithModelPath:self.albertModelPath]; + + XCTAssertNotNil(albertAnswerer); + + NSArray<TFLQAAnswer*>* answers = + [albertAnswerer answerWithContext:kContext question:kQuestion]; + + + XCTAssertEqualObjects([answers[0] text], kAnswer); +} +@end diff --git a/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.swift b/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.swift new file mode 100644 index 00000000..3f15cc52 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/qa/Tests/TFLBertQuestionAnswererTest.swift @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import XCTest + +@testable import TFLBertQuestionAnswerer + +class TFLBertQuestionAnswererTest: XCTestCase { + + static let bundle = Bundle(for: TFLBertQuestionAnswererTest.self) + static let mobileBertModelPath = bundle.path(forResource: "mobilebert_with_metadata", ofType: "tflite")! + + static let albertModelPath = bundle.path(forResource: "albert_with_metadata", ofType: "tflite")! + + static let context = """ + The role of teacher is often formal and ongoing, carried out at a school or other place of + formal education. In many countries, a person who wishes to become a teacher must first obtain + specified professional qualifications or credentials from a university or college. These + professional qualifications may include the study of pedagogy, the science of teaching. + Teachers, like other professionals, may have to continue their education after they qualify, + a process known as continuing professional development. Teachers may use a lesson plan to + facilitate student learning, providing a course of study which is called the curriculum. + """ + static let question = "What is a course of study called?" + static let answer = "the curriculum." 
+ + func testInitMobileBert() { + let mobileBertAnswerer = TFLBertQuestionAnswerer.questionAnswerer( + modelPath: TFLBertQuestionAnswererTest.mobileBertModelPath) + + XCTAssertNotNil(mobileBertAnswerer) + + let answers = mobileBertAnswerer.answer( + context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question) + + XCTAssertNotNil(answers) + XCTAssertEqual(answers[0].text, TFLBertQuestionAnswererTest.answer) + } + + func testInitAlbert() { + let albertAnswerer = TFLBertQuestionAnswerer.questionAnswerer( + modelPath: TFLBertQuestionAnswererTest.albertModelPath) + + XCTAssertNotNil(albertAnswerer) + + let answers = albertAnswerer.answer( + context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question) + + XCTAssertNotNil(answers) + XCTAssertEqual(answers[0].text, TFLBertQuestionAnswererTest.answer) + } +} diff --git a/tensorflow_lite_support/ios/text/tokenizers/BUILD b/tensorflow_lite_support/ios/text/tokenizers/BUILD new file mode 100644 index 00000000..34ba9c6b --- /dev/null +++ b/tensorflow_lite_support/ios/text/tokenizers/BUILD @@ -0,0 +1,106 @@ +load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") +load("@org_tensorflow//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") +load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLTokenizerUtil", + srcs = [ + "Sources/TFLTokenizerUtil.mm", + ], + hdrs = [ + "Sources/TFLTokenizerUtil.h", + ], + module_name = "TFLTokenizerUtil", + deps = [ + "//tensorflow_lite_support/cc/text/tokenizers:tokenizer", + "//tensorflow_lite_support/ios/utils:TFLStringUtil", + ], +) + +objc_library( + name = "TFLBertTokenizer", + srcs = [ + 
"Sources/TFLBertTokenizer.mm", + ], + hdrs = [ + "Sources/TFLBertTokenizer.h", + "Sources/TFLTokenizer.h", + ], + module_name = "TFLBertTokenizer", + deps = [ + ":TFLTokenizerUtil", + "//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer", + "//tensorflow_lite_support/ios/utils:TFLStringUtil", + ], +) + +swift_library( + name = "TFLBertTokenizerTestLibrary", + testonly = 1, + srcs = ["Tests/TFLBertTokenizerTest.swift"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:mobile_bert_model", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + ":TFLBertTokenizer", + "//third_party/swift/xctest", + ], +) + +ios_unit_test( + name = "TFLBertTokenizerTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLBertTokenizerTestLibrary", + ], +) + +objc_library( + name = "TFLSentencepieceTokenizer", + srcs = [ + "Sources/TFLSentencepieceTokenizer.mm", + ], + hdrs = [ + "Sources/TFLSentencepieceTokenizer.h", + "Sources/TFLTokenizer.h", + ], + module_name = "TFLSentencepieceTokenizer", + deps = [ + ":TFLTokenizerUtil", + "//tensorflow_lite_support/cc/text/tokenizers:sentencepiece_tokenizer", + "//tensorflow_lite_support/ios/utils:TFLStringUtil", + ], +) + +swift_library( + name = "TFLSentencepieceTokenizerTestLibrary", + testonly = 1, + srcs = ["Tests/TFLSentencepieceTokenizerTest.swift"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:albert_model", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + ":TFLSentencepieceTokenizer", + "//third_party/swift/xctest", + ], +) + +ios_unit_test( + name = "TFLSentencepieceTokenizerTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLSentencepieceTokenizerTestLibrary", + ], +) diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h 
b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h new file mode 100644 index 00000000..aa692489 --- /dev/null +++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h" + +NS_ASSUME_NONNULL_BEGIN +/** + * Wordpiece Tokenizer implemenation. + */ +@interface TFLBertTokenizer : NSObject <TFLTokenizer> + +/** + * Default initializer is not available. + */ +- (instancetype)init NS_UNAVAILABLE; + +/** + * Initializes the tokenizer with the path to wordpiece vocabulary file. + */ +- (instancetype)initWithVocabPath:(NSString *)vocabPath NS_DESIGNATED_INITIALIZER; + +/** + * Initializes the tokenizer with a list of tokens. + */ +- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab NS_DESIGNATED_INITIALIZER; +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm new file mode 100644 index 00000000..949cef2b --- /dev/null +++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.mm @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLBertTokenizer.h" +#include "third_party/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h" +#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h" +#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h" + +NS_ASSUME_NONNULL_BEGIN +using BertTokenizerCPP = ::tflite::support::text::tokenizer::BertTokenizer; + +@implementation TFLBertTokenizer { + std::unique_ptr<BertTokenizerCPP> _bertTokenizer; +} + +- (instancetype)initWithVocabPath:(NSString *)vocabPath { + self = [super init]; + if (self) { + _bertTokenizer = absl::make_unique<BertTokenizerCPP>(MakeString(vocabPath)); + } + return self; +} + +- (instancetype)initWithVocab:(NSArray<NSString *> *)vocab { + self = [super init]; + if (self) { + std::vector<std::string> vocabCpp; + vocabCpp.reserve([vocab count]); + for (NSString *word in vocab) { + vocabCpp.emplace_back(MakeString(word)); + } + _bertTokenizer = absl::make_unique<BertTokenizerCPP>(vocabCpp); + } + return self; +} + +- (NSArray<NSString *> *)tokensFromInput:(NSString *)input { + return Tokenize(_bertTokenizer.get(), input); +} + +- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens { + return ConvertTokensToIds(_bertTokenizer.get(), tokens); +} + +@end +NS_ASSUME_NONNULL_END diff --git 
a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h new file mode 100644 index 00000000..eef3bf1e --- /dev/null +++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h" + +NS_ASSUME_NONNULL_BEGIN +/** + * Sentencepiece Tokenizer implemenation. + */ +@interface TFLSentencepieceTokenizer : NSObject <TFLTokenizer> + +/** + * Default initializer is not available. + */ +- (instancetype)init NS_UNAVAILABLE; + +/** + * Initializes the tokenizer with the path to sentencepiece model file. + */ +- (instancetype)initWithModelPath:(NSString *)modelPath; +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm new file mode 100644 index 00000000..1e21cee5 --- /dev/null +++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.mm @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLSentencepieceTokenizer.h" +#include "third_party/absl/memory/memory.h" +#include "third_party/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h" +#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h" +#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h" + +NS_ASSUME_NONNULL_BEGIN +using SentencepieceTokenizerCPP = ::tflite::support::text::tokenizer::SentencePieceTokenizer; + +@implementation TFLSentencepieceTokenizer { + std::unique_ptr<SentencepieceTokenizerCPP> _spTokenizer; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath { + self = [super init]; + if (self) { + _spTokenizer = absl::make_unique<SentencepieceTokenizerCPP>(MakeString(modelPath)); + } + return self; +} + +- (NSArray<NSString *> *)tokensFromInput:(NSString *)input { + return Tokenize(_spTokenizer.get(), input); +} + +- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens { + return ConvertTokensToIds(_spTokenizer.get(), tokens); +} + +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h new file mode 100644 index 00000000..ee0972f8 --- /dev/null +++ 
b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizer.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import <Foundation/Foundation.h> + +NS_ASSUME_NONNULL_BEGIN +/** + * Protocol for a Tokenizer used in model proprocessing. + */ +@protocol TFLTokenizer + +/** + * Performs tokenization on input text. + * @param input The input string to be tokenized. + * + * @return A list of tokens. + */ +- (NSArray<NSString *> *)tokensFromInput:(NSString *)input; + +/* + * Convert a list of tokens back to their coressponding IDs. + * @param tokens The tokens to be converted. + * + * @return A list of ids. + */ +- (NSArray<NSNumber *> *)idsFromTokens:(NSArray<NSString *> *)tokens; +@end +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h new file mode 100644 index 00000000..574b5553 --- /dev/null +++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import <Foundation/Foundation.h> +#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h" + +using ::tflite::support::text::tokenizer::Tokenizer; + +/** + * Invokes the cpp tokenizer's tokenize function and converts input/output to objc. + * + * @param tokenizer The cpp tokenizer pointer. + * @param input The input string to be tokenized. + * + * @return A list of tokens. + */ +NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input); + +/** + * Invokes the cpp tokenizer's convertTokensToIds function and converts input/output to objc. + * + * @param tokenizer The cpp tokenizer pointer. + * @param input The tokens to be converted. + * + * @return A list of ids. + */ +NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens); diff --git a/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm new file mode 100644 index 00000000..52180578 --- /dev/null +++ b/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.mm @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#import "third_party/tensorflow_lite_support/ios/text/tokenizers/Sources/TFLTokenizerUtil.h" + +#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h" + +using ::tflite::support::text::tokenizer::TokenizerResult; + +NSArray<NSString *> *Tokenize(Tokenizer *tokenizer, NSString *input) { + TokenizerResult tokenize_result = tokenizer->Tokenize(MakeString(input)); + std::vector<std::string> subwords = tokenize_result.subwords; + NSMutableArray<NSString *> *ret = [NSMutableArray arrayWithCapacity:subwords.size()]; + for (int i = 0; i < subwords.size(); ++i) { + [ret addObject:MakeNSString(subwords[i])]; + } + return ret; +} + +NSArray<NSNumber *> *ConvertTokensToIds(Tokenizer *tokenizer, NSArray<NSString *> *tokens) { + NSMutableArray<NSNumber *> *ret = [NSMutableArray arrayWithCapacity:[tokens count]]; + for (NSString *token in tokens) { + std::string cc_token = MakeString(token); + const char *cToken = cc_token.c_str(); + int id; + tokenizer->LookupId(cToken, &id); + [ret addObject:[NSNumber numberWithInt:id]]; + } + return ret; +} diff --git a/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift b/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift new file mode 100644 index 00000000..e805f301 --- /dev/null +++ b/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLBertTokenizerTest.swift @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import XCTest + +@testable import TFLBertTokenizer + +class TFLBertTokenizerTest: XCTestCase { + static let bundle = Bundle(for: TFLBertTokenizerTest.self) + static let mobileBertVocabPath = bundle.path(forResource: "vocab", ofType: "txt")! + + func testInitBertTokenizerFromPath() { + let bertTokenizer = TFLBertTokenizer(vocabPath: TFLBertTokenizerTest.mobileBertVocabPath) + + XCTAssertNotNil(bertTokenizer) + + let tokens = bertTokenizer.tokens(fromInput: "i'm questionansweraskask") + + XCTAssertEqual(tokens, ["i", "'", "m", "question", "##ans", "##wer", "##ask", "##ask"]) + + let ids = bertTokenizer.ids(fromTokens: tokens) + + XCTAssertEqual(ids, [1045, 1005, 1049, 3160, 6962, 13777, 19895, 19895]) + } + + func testInitBertTokenizerFromVocab() { + let bertTokenizer = TFLBertTokenizer(vocab: ["hell", "##o", "wor", "##ld", "there"]) + + XCTAssertNotNil(bertTokenizer) + + let tokens = bertTokenizer.tokens(fromInput: "hello there hello world") + + XCTAssertEqual(tokens, ["hell", "##o", "there", "hell", "##o", "wor", "##ld"]) + + let ids = bertTokenizer.ids(fromTokens: tokens) + + XCTAssertEqual(ids, [0, 1, 4, 0, 1, 2, 3]) + } +} diff --git a/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLSentencepieceTokenizerTest.swift b/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLSentencepieceTokenizerTest.swift new file mode 100644 index 00000000..c7c6d1e2 --- 
/dev/null +++ b/tensorflow_lite_support/ios/text/tokenizers/Tests/TFLSentencepieceTokenizerTest.swift @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +import XCTest + +@testable import TFLSentencepieceTokenizer + +class TFLSentencepieceTokenizerTest: XCTestCase { + static let bundle = Bundle(for: TFLSentencepieceTokenizerTest.self) + static let spModelPath = bundle.path(forResource: "30k-clean", ofType: "model")! 
+ + func testInitSentenpieceTokenizerFromPath() { + let spTokenizer = TFLSentencepieceTokenizer( + modelPath: TFLSentencepieceTokenizerTest.spModelPath) + + XCTAssertNotNil(spTokenizer) + + let tokens = spTokenizer.tokens(fromInput: "good morning, i'm your teacher.\n") + + XCTAssertEqual(tokens, ["▁good", "▁morning", ",", "▁i", "'", "m", "▁your", "▁teacher", "."]) + + let ids = spTokenizer.ids(fromTokens: tokens) + + XCTAssertEqual(ids, [254, 959, 15, 31, 22, 79, 154, 2197, 9]) + } +} diff --git a/tensorflow_lite_support/ios/utils/BUILD b/tensorflow_lite_support/ios/utils/BUILD new file mode 100644 index 00000000..63f10915 --- /dev/null +++ b/tensorflow_lite_support/ios/utils/BUILD @@ -0,0 +1,15 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLStringUtil", + srcs = [ + "Sources/TFLStringUtil.mm", + ], + hdrs = [ + "Sources/TFLStringUtil.h", + ], + module_name = "TFLStringUtil", +) diff --git a/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h b/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h new file mode 100644 index 00000000..3ea091f5 --- /dev/null +++ b/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h @@ -0,0 +1,23 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#import <Foundation/Foundation.h> + +#include <string> + +// Translates a NSString encoded in UTF-8 to a std::string. +std::string MakeString(NSString*); + +// Translates a std::string to the equivalent NSString by making a copy. +NSString* MakeNSString(const std::string&); diff --git a/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm b/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm new file mode 100644 index 00000000..6e9cf238 --- /dev/null +++ b/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.mm @@ -0,0 +1,23 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#import "third_party/tensorflow_lite_support/ios/utils/Sources/TFLStringUtil.h" + +std::string MakeString(NSString* str) { return std::string([str UTF8String]); } + +NSString* MakeNSString(const std::string& str) { + return [[NSString alloc] initWithBytes:const_cast<void*>(static_cast<const void*>(str.data())) + length:str.length() + encoding:NSUTF8StringEncoding]; +} diff --git a/tensorflow_lite_support/java/AndroidManifest.xml b/tensorflow_lite_support/java/AndroidManifest.xml new file mode 100644 index 00000000..14909296 --- /dev/null +++ b/tensorflow_lite_support/java/AndroidManifest.xml @@ -0,0 +1,5 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.support"> + <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> +</manifest> diff --git a/tensorflow_lite_support/java/BUILD b/tensorflow_lite_support/java/BUILD new file mode 100644 index 00000000..090b1add --- /dev/null +++ b/tensorflow_lite_support/java/BUILD @@ -0,0 +1,78 @@ +# Description: +# TensorFlow Lite Support API in Java. + +load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "AndroidManifest.xml", + "default_version_script.lds", + "debug_version_script.lds", +]) + +# Android Library target for TFLite Support Library. It depends on TensorFlow +# Lite runtime (tensorflow/lite/java:tensorflowlite). If you don't want to +# introduce the native library into dependencies, use +# "tensorflowlite_support_java" instead, which depends on +# tensorflow/lite/java:tensorflowlite_java. 
+android_library( + name = "tensorflowlite_support", + srcs = glob( + ["src/java/org/tensorflow/lite/support/**/*.java"], + ), + javacopts = JAVACOPTS, + manifest = "AndroidManifest.xml", + deps = [ + "@org_checkerframework_qual", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite", + ], +) + +android_library( + name = "tensorflowlite_support_java", + srcs = glob( + ["src/java/org/tensorflow/lite/support/**/*.java"], + ), + javacopts = JAVACOPTS, + manifest = "AndroidManifest.xml", + deps = [ + "@org_checkerframework_qual", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + ], +) + +# TODO(b/156482505): Remove this target. +alias( + name = "tensorflow-lite-support-nogpu", + actual = ":tensorflow-lite-support", +) + +# This alias matches the associated .aar library name output style. +alias( + name = "tensorflow-lite-support", + actual = ":tensorflowlite_support", +) + +java_library( + name = "tensorflowlite_support_precondition_lib", + srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"], + javacopts = JAVACOPTS, + deps = [ + "@org_checkerframework_qual", + ], +) + +android_library( + name = "tensorflowlite_support_precondition", + srcs = ["src/java/org/tensorflow/lite/support/common/SupportPreconditions.java"], + javacopts = JAVACOPTS, + manifest = "AndroidManifest.xml", + deps = [ + "@org_checkerframework_qual", + ], +) diff --git a/tensorflow_lite_support/java/README.md b/tensorflow_lite_support/java/README.md new file mode 100644 index 00000000..8d37bf8b --- /dev/null +++ b/tensorflow_lite_support/java/README.md @@ -0,0 +1,38 @@ +# TensorFlow Lite Support + +TensorFlow Lite Support contains a set of tools and libraries that help +developing ML with TFLite for mobile apps. See the [documentation on +tensorflow.org](https://www.tensorflow.org/lite/inference_with_metadata/overview) +for more information about all the efforts under TensorFlow Lite Support. 
+ +This directory contains the Java code for the TensorFlow Lite SupportLibrary +and TensorFlow Lite Task Library. + +## TensorFlow Lite Android Support Library + +Mobile application developers typically interact with typed objects such as +bitmaps or primitives such as integers. However, the TensorFlow Lite Interpreter +that runs the on-device machine learning model uses tensors in the form of +ByteBuffer, which can be difficult to debug and manipulate. The TensorFlow Lite +Android Support Library is designed to help process the input and output of +TensorFlow Lite models, and make the TensorFlow Lite interpreter easier to use. + +We welcome feedback from the community as we develop this support library, +especially around: + +* Use-cases we should support including data types and operations +* Ease of use - does the APIs make sense to the community + +See the [documentation](https://www.tensorflow.org/lite/inference_with_metadata/lite_support) +for more instruction and examples. + + +## TensorFlow Lite Android Task Library +TensorFlow Lite Task Library provides optimized ready-to-use model interfaces +for popular machine learning tasks, such as image classification, question and +answer, etc. The model interfaces are specifically designed for each task to +achieve the best performance and usability. Task Library works cross-platform +and is supported on Java, C++, and Swift. + +See the [documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview) +for more instruction and examples. diff --git a/tensorflow_lite_support/java/debug_version_script.lds b/tensorflow_lite_support/java/debug_version_script.lds new file mode 100644 index 00000000..53553a42 --- /dev/null +++ b/tensorflow_lite_support/java/debug_version_script.lds @@ -0,0 +1,5 @@ +VERS_1.0 { + # Export everything for debug purpose. 
+ global: + *; +}; diff --git a/tensorflow_lite_support/java/default_version_script.lds b/tensorflow_lite_support/java/default_version_script.lds new file mode 100644 index 00000000..46bbffe7 --- /dev/null +++ b/tensorflow_lite_support/java/default_version_script.lds @@ -0,0 +1,12 @@ +VERS_1.0 { + # Export JNI and native C symbols. + global: + Java_*; + JNI_OnLoad; + JNI_OnUnload; + TfLite*; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow_lite_support/java/jni/BUILD b/tensorflow_lite_support/java/jni/BUILD new file mode 100644 index 00000000..2c01b50d --- /dev/null +++ b/tensorflow_lite_support/java/jni/BUILD @@ -0,0 +1,48 @@ +package(default_visibility = ["//tensorflow_lite_support:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +# Helper target for exposing JNI headers across multiple platforms. +cc_library( + name = "jni", + hdrs = select({ + # The Android toolchain makes "jni.h" available in the include path. + # For non-Android toolchains, generate jni.h and jni_md.h. + "//tensorflow_lite_support:android": [], + "//conditions:default": [ + ":jni.h", + ":jni_md.h", + ], + }), + includes = select({ + "//tensorflow_lite_support:android": [], + "//conditions:default": ["."], + }), + visibility = ["//visibility:public"], +) + +# Silly rules to make +# #include <jni.h> +# in the source headers work +# (in combination with the "includes" attribute of the tf_cuda_library rule +# above. Not needed when using the Android toolchain). +# +# Inspired from: +# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD +# but hopefully there is a simpler alternative to this. 
+genrule( + name = "copy_jni_h", + srcs = ["@bazel_tools//tools/jdk:jni_header"], + outs = ["jni.h"], + cmd = "cp -f $< $@", +) + +genrule( + name = "copy_jni_md_h", + srcs = select({ + "//tensorflow_lite_support:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], + "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], + }), + outs = ["jni_md.h"], + cmd = "cp -f $< $@", +) diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java new file mode 100644 index 00000000..e83fd403 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/FileUtil.java @@ -0,0 +1,184 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.common; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** File I/O utilities. */ +public class FileUtil { + private FileUtil() {} + + /** + * Loads labels from the label file into a list of strings. + * + * <p>A legal label file is the plain text file whose contents are split into lines, and each line + * is an individual value. The file should be in assets of the context. + * + * @param context The context holds assets. + * @param filePath The path of the label file, relative with assets directory. + * @return a list of labels. + * @throws IOException if error occurs to open or read the file. + */ + @NonNull + public static List<String> loadLabels(@NonNull Context context, @NonNull String filePath) + throws IOException { + return loadLabels(context, filePath, Charset.defaultCharset()); + } + + /** + * Loads labels from the label file into a list of strings. + * + * <p>A legal label file is the plain text file whose contents are split into lines, and each line + * is an individual value. The empty lines will be ignored. The file should be in assets of the + * context. + * + * @param context The context holds assets. + * @param filePath The path of the label file, relative with assets directory. + * @param cs {@code Charset} to use when decoding content of label file. + * @return a list of labels. + * @throws IOException if error occurs to open or read the file. 
+ */ + @NonNull + public static List<String> loadLabels( + @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException { + SupportPreconditions.checkNotNull(context, "Context cannot be null."); + SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); + try (InputStream inputStream = context.getAssets().open(filePath)) { + return loadLabels(inputStream, cs); + } + } + + /** + * Loads labels from an input stream of an opened label file. See details for label files in + * {@link FileUtil#loadLabels(Context, String)}. + * + * @param inputStream the input stream of an opened label file. + * @return a list of labels. + * @throws IOException if error occurs to open or read the file. + */ + @NonNull + public static List<String> loadLabels(@NonNull InputStream inputStream) throws IOException { + return loadLabels(inputStream, Charset.defaultCharset()); + } + + /** + * Loads labels from an input stream of an opened label file. See details for label files in + * {@link FileUtil#loadLabels(Context, String)}. + * + * @param inputStream the input stream of an opened label file. + * @param cs {@code Charset} to use when decoding content of label file. + * @return a list of labels. + * @throws IOException if error occurs to open or read the file. + */ + @NonNull + public static List<String> loadLabels(@NonNull InputStream inputStream, Charset cs) + throws IOException { + List<String> labels = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, cs))) { + String line; + while ((line = reader.readLine()) != null) { + if (line.trim().length() > 0) { + labels.add(line); + } + } + return labels; + } + } + + /** + * Loads a vocabulary file (a single-column text file) into a list of strings. + * + * <p>A vocabulary file is a single-column plain text file whose contents are split into lines, + * and each line is an individual value. The file should be in assets of the context. 
+ * + * @param context The context holds assets. + * @param filePath The path of the vocabulary file, relative with assets directory. + * @return a list of vocabulary words. + * @throws IOException if error occurs to open or read the file. + */ + @NonNull + public static List<String> loadSingleColumnTextFile( + @NonNull Context context, @NonNull String filePath, Charset cs) throws IOException { + return loadLabels(context, filePath, cs); + } + + /** + * Loads vocabulary from an input stream of an opened vocabulary file (which is a single-column + * text file). See details for vocabulary files in {@link FileUtil#loadVocabularyFile(Context, + * String)}. + * + * @param inputStream the input stream of an opened vocabulary file. + * @return a list of vocabulary words. + * @throws IOException if error occurs to open or read the file. + */ + @NonNull + public static List<String> loadSingleColumnTextFile(@NonNull InputStream inputStream, Charset cs) + throws IOException { + return loadLabels(inputStream, cs); + } + + /** + * Loads a file from the asset folder through memory mapping. + * + * @param context Application context to access assets. + * @param filePath Asset path of the file. + * @return the loaded memory mapped file. + * @throws IOException if an I/O error occurs when loading the tflite model. 
+ */ + @NonNull + public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath) + throws IOException { + SupportPreconditions.checkNotNull(context, "Context should not be null."); + SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); + try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } + } + + /** + * Loads a binary file from the asset folder. + * + * @param context Application context to access assets. + * @param filePath Asset path of the file. + * @return the byte array for the binary file. + * @throws IOException if an I/O error occurs when loading file. + */ + @NonNull + public static byte[] loadByteFromFile(@NonNull Context context, @NonNull String filePath) + throws IOException { + ByteBuffer buffer = loadMappedFile(context, filePath); + byte[] byteArray = new byte[buffer.remaining()]; + buffer.get(byteArray); + return byteArray; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java new file mode 100644 index 00000000..38dfe881 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Operator.java @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common; + +/** + * The common interface for classes that carries an "apply" method, which converts T to another one. + * @param <T> The class which Operator handles. + */ +public interface Operator<T> { + + /** + * Applies an operation on a T object, returning a T object. + * + * <p>Note: The returned object could probably be the same one with given input, and given input + * could probably be changed. + */ + T apply(T x); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java new file mode 100644 index 00000000..07d7e2bd --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/Processor.java @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.common; + +/** + * Processes T object with prepared {@link Operator<T>}. + */ +public interface Processor<T> { + T process(T input); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java new file mode 100644 index 00000000..ff0c6406 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SequentialProcessor.java @@ -0,0 +1,82 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** + * A processor base class that chains a serial of {@link Operator<T>} and executes them. + * + * <p>Typically, users could use its subclasses, e.g. {@link + * org.tensorflow.lite.support.image.ImageProcessor} rather than directly use this one. + * + * @param <T> The type that the Operator is handling. 
+ */ +public class SequentialProcessor<T> implements Processor<T> { + + /** List of operators added to this {@link SequentialProcessor}. */ + protected final List<Operator<T>> operatorList; + /** + * The {@link Map} between the operator name and the corresponding op indexes in {@code + * operatorList}. An operator may be added multiple times into this {@link SequentialProcessor}. + */ + protected final Map<String, List<Integer>> operatorIndex; + + protected SequentialProcessor(Builder<T> builder) { + operatorList = builder.operatorList; + operatorIndex = Collections.unmodifiableMap(builder.operatorIndex); + } + + @Override + public T process(T x) { + for (Operator<T> op : operatorList) { + x = op.apply(x); + } + return x; + } + + /** The inner builder class to build a Sequential Processor. */ + protected static class Builder<T> { + + private final List<Operator<T>> operatorList; + private final Map<String, List<Integer>> operatorIndex; + + protected Builder() { + operatorList = new ArrayList<>(); + operatorIndex = new HashMap<>(); + } + + public Builder<T> add(@NonNull Operator<T> op) { + SupportPreconditions.checkNotNull(op, "Adding null Op is illegal."); + operatorList.add(op); + String operatorName = op.getClass().getName(); + if (!operatorIndex.containsKey(operatorName)) { + operatorIndex.put(operatorName, new ArrayList<Integer>()); + } + operatorIndex.get(operatorName).add(operatorList.size() - 1); + return this; + } + + public SequentialProcessor<T> build() { + return new SequentialProcessor<T>(this); + } + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java new file mode 100644 index 00000000..8620e13e --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/SupportPreconditions.java @@ -0,0 +1,184 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common; + +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Static error checking util methods. */ +public final class SupportPreconditions { + /** + * Ensures that an object reference passed as a parameter to the calling method is not null. + * + * @param reference an object reference + * @return the non-null reference that was validated + * @throws NullPointerException if {@code reference} is null + */ + public static <T extends Object> T checkNotNull(T reference) { + if (reference == null) { + throw new NullPointerException("The object reference is null."); + } + return reference; + } + + /** + * Ensures that an object reference passed as a parameter to the calling method is not null. + * + * @param reference an object reference + * @param errorMessage the exception message to use if the check fails; will be converted to a + * string using {@link String#valueOf(Object)} + * @return the non-null reference that was validated + * @throws NullPointerException if {@code reference} is null + */ + public static <T extends Object> T checkNotNull(T reference, @Nullable Object errorMessage) { + if (reference == null) { + throw new NullPointerException(String.valueOf(errorMessage)); + } + return reference; + } + + /** + * Ensures that the given String is not empty and not null. 
+ * + * @param string the String to test + * @return the non-null non-empty String that was validated + * @throws IllegalArgumentException if {@code string} is null or empty + */ + public static String checkNotEmpty(String string) { + if (string == null || string.length() == 0) { + throw new IllegalArgumentException("Given String is empty or null."); + } + return string; + } + + /** + * Ensures that the given String is not empty and not null. + * + * @param string the String to test + * @param errorMessage the exception message to use if the check fails; will be converted to a + * string using {@link String#valueOf(Object)} + * @return the non-null non-empty String that was validated + * @throws IllegalArgumentException if {@code string} is null or empty + */ + public static String checkNotEmpty(String string, Object errorMessage) { + if (string == null || string.length() == 0) { + throw new IllegalArgumentException(String.valueOf(errorMessage)); + } + return string; + } + + /** + * Ensures the truth of an expression involving one or more parameters to the calling method. + * + * @param expression a boolean expression. + * @throws IllegalArgumentException if {@code expression} is false. + */ + public static void checkArgument(boolean expression) { + if (!expression) { + throw new IllegalArgumentException(); + } + } + + /** + * Ensures the truth of an expression involving one or more parameters to the calling method. + * + * @param expression a boolean expression. + * @param errorMessage the exception message to use if the check fails; will be converted to a + * string using {@link String#valueOf(Object)}. + * @throws IllegalArgumentException if {@code expression} is false. 
+ */ + public static void checkArgument(boolean expression, @Nullable Object errorMessage) { + if (!expression) { + throw new IllegalArgumentException(String.valueOf(errorMessage)); + } + } + + /** + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. + * + * @param index a user-supplied index identifying an element of an array, list or string + * @param size the size of that array, list or string + * @return the value of {@code index} + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size} + * @throws IllegalArgumentException if {@code size} is negative + */ + public static int checkElementIndex(int index, int size) { + return checkElementIndex(index, size, "index"); + } + + /** + * Ensures that {@code index} specifies a valid <i>element</i> in an array, list or string of size + * {@code size}. An element index may range from zero, inclusive, to {@code size}, exclusive. + * + * @param index a user-supplied index identifying an element of an array, list or string + * @param size the size of that array, list or string + * @param desc the text to use to describe this index in an error message + * @return the value of {@code index} + * @throws IndexOutOfBoundsException if {@code index} is negative or is not less than {@code size} + * @throws IllegalArgumentException if {@code size} is negative + */ + public static int checkElementIndex(int index, int size, @Nullable String desc) { + // Carefully optimized for execution by hotspot (explanatory comment above) + if (index < 0 || index >= size) { + throw new IndexOutOfBoundsException(badElementIndex(index, size, desc)); + } + return index; + } + + /** + * Ensures the truth of an expression involving the state of the calling instance, but not + * involving any parameters to the calling method. 
+ * + * @param expression a boolean expression + * @throws IllegalStateException if {@code expression} is false + * @see Verify#verify Verify.verify() + */ + public static void checkState(boolean expression) { + if (!expression) { + throw new IllegalStateException(); + } + } + + /** + * Ensures the truth of an expression involving the state of the calling instance, but not + * involving any parameters to the calling method. + * + * @param expression a boolean expression + * @param errorMessage the exception message to use if the check fails; will be converted to a + * string using {@link String#valueOf(Object)} + * @throws IllegalStateException if {@code expression} is false + * @see Verify#verify Verify.verify() + */ + public static void checkState(boolean expression, @Nullable Object errorMessage) { + if (!expression) { + throw new IllegalStateException(String.valueOf(errorMessage)); + } + } + + private static String badElementIndex(int index, int size, @Nullable String desc) { + if (index < 0) { + return String.format("%s (%s) must not be negative", desc, index); + } else if (size < 0) { + throw new IllegalArgumentException("negative size: " + size); + } else { // index >= size + return String.format("%s (%s) must be less than size (%s)", desc, index, size); + } + } + + private SupportPreconditions() { + throw new AssertionError("SupportPreconditions is Uninstantiable."); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java new file mode 100644 index 00000000..d1b7021d --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorOperator.java @@ -0,0 +1,27 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common; + +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Applies some operation on TensorBuffers. + */ +public interface TensorOperator extends Operator<TensorBuffer> { + /** @see Operator#apply(Object) . */ + @Override + TensorBuffer apply(TensorBuffer input); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java new file mode 100644 index 00000000..31531b2e --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/TensorProcessor.java @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.common; + +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * TensorProcessor is a helper class for preprocessing and postprocessing tensors. It could + * transform a {@link TensorBuffer} to another by executing a chain of {@link TensorOperator}. + * + * <p>Example Usage: + * + * <pre> + * TensorProcessor processor = new TensorProcessor.Builder().add(new NormalizeOp(1, 2)).build(); + * TensorBuffer anotherTensorBuffer = processor.process(tensorBuffer); + * </pre> + * + * @see TensorProcessor.Builder to build a {@link TensorProcessor} instance. + * @see TensorProcessor#process(TensorBuffer) to apply the processor on a {@link TensorBuffer}. + */ +public class TensorProcessor extends SequentialProcessor<TensorBuffer> { + private TensorProcessor(Builder builder) { + super(builder); + } + + /** The Builder to create an {@link TensorProcessor}, which could be executed later. */ + public static class Builder extends SequentialProcessor.Builder<TensorBuffer> { + + /** + * Creates a Builder to build {@link TensorProcessor}. + * + * @see #add(TensorOperator) to add an Op. + * @see #build() to complete the building process and get a built Processor. + */ + public Builder() { + super(); + } + + /** + * Adds an {@link TensorOperator} into the Operator chain. + * + * @param op the Operator instance to be executed then. + */ + public TensorProcessor.Builder add(TensorOperator op) { + super.add(op); + return this; + } + + /** Completes the building process and gets the {@link TensorProcessor} instance. 
*/ + @Override + public TensorProcessor build() { + return new TensorProcessor(this); + } + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java new file mode 100644 index 00000000..3355b185 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/CastOp.java @@ -0,0 +1,55 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common.ops; + +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Casts a {@link TensorBuffer} to a specified data type. */ +public class CastOp implements TensorOperator { + + private final DataType destinationType; + + /** + * Constructs a CastOp. + * + * <p>Note: For only converting type for a certain {@link TensorBuffer} on-the-fly rather than in + * a processor, please directly use {@link TensorBuffer#createFrom(TensorBuffer, DataType)}. + * + * <p>When this Op is executed, if the original {@link TensorBuffer} is already in {@code + * destinationType}, the original buffer will be directly returned. 
@param destinationType The type of the cast {@link TensorBuffer}.
+==============================================================================*/ + +package org.tensorflow.lite.support.common.ops; + +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Dequantizes a {@link TensorBuffer} with given {@code zeroPoint} and {@code scale}. + * + * <p>Note: The data type of output tensor is always {@code FLOAT32} except when the DequantizeOp is + * created effectively as an identity Op such as setting {@code zeroPoint} to 0 and {@code scale} to + * 1 (in this case, the output tensor is the same instance as input). + * + * <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link DequantizeOp} will be bypassed, + * which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful + * when passing in the quantization parameters that are extracted directly from the TFLite model + * flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read + * as 0. + */ +public class DequantizeOp extends NormalizeOp implements TensorOperator { + + public DequantizeOp(float zeroPoint, float scale) { + // Quantization: f = (q - z) * s + super(zeroPoint, 1 / scale); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java new file mode 100644 index 00000000..9db1388b --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/NormalizeOp.java @@ -0,0 +1,160 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.common.ops; + +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; +import org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat; + +/** + * Normalizes a {@link TensorBuffer} with given mean and stddev: output = (input - mean) / stddev. + */ +public class NormalizeOp implements TensorOperator { + + // mean.length should always be equal to stddev.length and always >= 1. + private final float[] mean; + private final float[] stddev; + private final int numChannels; + private final boolean isIdentityOp; + + /** + * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which + * satisfies: + * + * <pre> + * output = (input - mean) / stddev + * </pre> + * + * <p>In the following two cases, reset {@code mean} to 0 and {@code stddev} to 1 to bypass the + * normalization. <br> + * 1. Both {@code mean} and {code stddev} are 0. <br> + * 2. {@code mean} is 0 and {stddev} is Infinity. + * + * <p>Note: If {@code mean} is set to 0 and {@code stddev} is set to 1, no computation will + * happen, and original input will be directly returned in execution. 
+ * + * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at + * present, except when the input is a {@link DataType#UINT8} tensor, {@code mean} is set to 0 and + * {@code stddev} is set to 1, so that the original {@link DataType#UINT8} tensor is returned. + * + * @param mean the mean value to be subtracted first. + * @param stddev the standard deviation value to divide then. + * @throws IllegalArgumentException if {@code stddev} is zero. + */ + public NormalizeOp(float mean, float stddev) { + // Make exceptions to the cases that + // 1. Both mean and stddev are 0.0f. This may happen when reading the normalization parameters + // from a tensor which does not have the values populated in the metadata. The same situation + // may also happen to the quantization parameters. + // 2. mean is 0.0f and stddev is Infinity. This may happen when reading the quantization + // parameters from a tensor which does not have the values populated in the metadata, and then + // passing the parameters into the DequantizeOp. + // Bypass both of the two cases, by reseting stddev to 1.0f. + if (mean == 0.0f && (stddev == 0.0f || Float.isInfinite(stddev))) { + stddev = 1.0f; + } + + SupportPreconditions.checkArgument(stddev != 0.0f, "Stddev cannot be zero."); + boolean meansIsZeroAndDevsIs1 = false; + if (mean == 0.0f && stddev == 1.0f) { + meansIsZeroAndDevsIs1 = true; + } + + this.isIdentityOp = meansIsZeroAndDevsIs1; + this.mean = new float[] {mean}; + this.stddev = new float[] {stddev}; + this.numChannels = 1; + } + + /** + * Initializes a NormalizeOp. When being called, it creates a new {@link TensorBuffer}, which + * satisfies: + * + * <pre> + * // Pseudo code. [...][i] means a certain element whose channel id is i. 
+ * output[...][i] = (input[...][i] - mean[i]) / stddev[i] + * </pre> + * + * <p>Note: If all values in {@code mean} are set to 0 and all {@code stddev} are set to 1, no + * computation will happen, and original input will be directly returned in execution. + * + * <p>Note: The returned {@link TensorBuffer} is always a {@link DataType#FLOAT32} tensor at + * present, except that the input is a {@link DataType#UINT8} tensor, all {@code mean} are set to + * 0 and all {@code stddev} are set to 1. + * + * @param mean the mean values to be subtracted first for each channel. + * @param stddev the standard deviation values to divide then for each channel. + * @throws IllegalArgumentException if any {@code stddev} is zero, or {@code mean} has different + * number of elements with {@code stddev}, or any of them is empty. + */ + public NormalizeOp(@NonNull float[] mean, @NonNull float[] stddev) { + SupportPreconditions.checkNotNull(mean, "Mean cannot be null"); + SupportPreconditions.checkNotNull(stddev, "Stddev cannot be null"); + SupportPreconditions.checkArgument( + mean.length == stddev.length, + "Per channel normalization requires same number of means and stddevs"); + SupportPreconditions.checkArgument(mean.length > 0, "Means and stddevs are empty."); + this.mean = mean.clone(); + this.stddev = stddev.clone(); + boolean allMeansAreZeroAndAllDevsAre1 = true; + this.numChannels = mean.length; + for (int i = 0; i < numChannels; i++) { + SupportPreconditions.checkArgument(this.stddev[i] != 0, "Stddev cannot be zero."); + if (this.stddev[i] != 1 || this.mean[i] != 0) { + allMeansAreZeroAndAllDevsAre1 = false; + } + } + this.isIdentityOp = allMeansAreZeroAndAllDevsAre1; + } + + /** + * Applies the defined normalization on given tensor and returns the result. + * + * <p>Note: {@code input} is possibly the same instance with the output. + * + * @param input input tensor. It may be the same instance with the output. + * @return output tensor. 
+ */ + @Override + @NonNull + public TensorBuffer apply(@NonNull TensorBuffer input) { + if (isIdentityOp) { + return input; + } + int[] shape = input.getShape(); + SupportPreconditions.checkArgument( + numChannels == 1 || (shape.length != 0 && shape[shape.length - 1] == numChannels), + "Number of means (stddevs) is not same with number of channels (size of last axis)."); + // TODO(136750944): Eliminate the array copy here. + float[] values = input.getFloatArray(); + int j = 0; + for (int i = 0; i < values.length; i++) { + values[i] = (values[i] - mean[j]) / stddev[j]; + j = (j + 1) % numChannels; + } + TensorBuffer output; + if (input.isDynamic()) { + output = TensorBufferFloat.createDynamic(DataType.FLOAT32); + } else { + output = TensorBufferFloat.createFixedSize(shape, DataType.FLOAT32); + } + output.loadArray(values, shape); + return output; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java new file mode 100644 index 00000000..8b3e82ae --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/common/ops/QuantizeOp.java @@ -0,0 +1,41 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.common.ops; + +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Quantizes a {@link TensorBuffer} with given {@code zeroPoint} and {@code scale}. + * + * <p>Note: {@link QuantizeOp} does not cast output to UINT8, but only performs the quantization + * math on top of input. The data type of output tensor is always {@code FLOAT32} except that the Op + * is effectively an identity Op (in this case, the output tensor is the same instance as the + * input). To connect with quantized model, a {@link CastOp} is probably needed. + * + * <p>If both {@code zeroPoint} and {@code scale} are 0, the {@link QuantizeOp} will be bypassed, + * which is equivalent to setting {@code zeroPoint} to 0 and {@code scale} to 1. This can be useful + * when passing in the quantization parameters that are extracted directly from the TFLite model + * flatbuffer. If the tensor is not quantized, both {@code zeroPoint} and {@code scale} will be read + * as 0. + */ +public class QuantizeOp extends NormalizeOp implements TensorOperator { + + public QuantizeOp(float zeroPoint, float scale) { + // Quantization: f = (q - z) * s, i.e. q = f / s + z = (f - (-z * s)) / s + super(-zeroPoint * scale, scale); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java new file mode 100644 index 00000000..b9590bfd --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BitmapContainer.java @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Holds a {@link Bitmap} and converts it to other image formats as needed. */ +final class BitmapContainer implements ImageContainer { + + private final Bitmap bitmap; + + /** + * Creates a {@link BitmapContainer} object with ARGB_8888 {@link Bitmap}. + * + * @throws IllegalArgumentException if the bitmap configuration is not ARGB_8888 + */ + static BitmapContainer create(Bitmap bitmap) { + return new BitmapContainer(bitmap); + } + + private BitmapContainer(Bitmap bitmap) { + checkNotNull(bitmap, "Cannot load null bitmap."); + checkArgument( + bitmap.getConfig().equals(Config.ARGB_8888), "Only supports loading ARGB_8888 bitmaps."); + this.bitmap = bitmap; + } + + @Override + public BitmapContainer clone() { + return create(bitmap.copy(bitmap.getConfig(), bitmap.isMutable())); + } + + @Override + public Bitmap getBitmap() { + // Not making a defensive copy for performance considerations. During image processing, + // users may need to set and get the bitmap many times. 
+ return bitmap; + } + + @Override + public TensorBuffer getTensorBuffer(DataType dataType) { + TensorBuffer buffer = TensorBuffer.createDynamic(dataType); + ImageConversions.convertBitmapToTensorBuffer(bitmap, buffer); + return buffer; + } + + @Override + public int getWidth() { + return bitmap.getWidth(); + } + + @Override + public int getHeight() { + return bitmap.getHeight(); + } + + @Override + public ColorSpaceType getColorSpaceType() { + return ColorSpaceType.fromBitmapConfig(bitmap.getConfig()); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java new file mode 100644 index 00000000..1a463305 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/BoundingBoxUtil.java @@ -0,0 +1,244 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; + +import android.graphics.RectF; +import java.nio.ByteBuffer; +import java.nio.FloatBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Helper class for converting values that represents bounding boxes into rectangles. + * + * <p>The class provides a static function to create bounding boxes as {@link RectF} from different + * types of configurations. + * + * <p>Generally, a bounding box could be represented by 4 float values, but the values could be + * interpreted in many ways. We now support 3 {@link Type} of configurations, and the order of + * elements in each type is configurable as well. + */ +public final class BoundingBoxUtil { + + /** Denotes how a bounding box is represented. */ + public enum Type { + /** + * Represents the bounding box by using the combination of boundaries, {left, top, right, + * bottom}. The default order is {left, top, right, bottom}. Other orders can be indicated by an + * index array. + */ + BOUNDARIES, + /** + * Represents the bounding box by using the upper_left corner, width and height. The default + * order is {upper_left_x, upper_left_y, width, height}. Other orders can be indicated by an + * index array. + */ + UPPER_LEFT, + /** + * Represents the bounding box by using the center of the box, width and height. The default + * order is {center_x, center_y, width, height}. Other orders can be indicated by an index + * array. + */ + CENTER, + } + + /** Denotes if the coordinates are actual pixels or relative ratios. */ + public enum CoordinateType { + /** The coordinates are relative ratios in range [0, 1]. */ + RATIO, + /** The coordinates are actual pixel values. 
*/ + PIXEL + } + + /** + * Creates a list of bounding boxes from a {@link TensorBuffer} which represents bounding boxes. + * + * @param tensor holds the data representing some boxes. + * @param valueIndex denotes the order of the elements defined in each bounding box type. An empty + * index array represent the default order of each bounding box type. For example, to denote + * the default order of BOUNDARIES, {left, top, right, bottom}, the index should be {0, 1, 2, + * 3}. To denote the order {left, right, top, bottom}, the order should be {0, 2, 1, 3}. + * <p>The index array can be applied to all bounding box types to adjust the order of their + * corresponding underlying elements. + * @param boundingBoxAxis specifies the index of the dimension that represents bounding box. The + * size of that dimension is required to be 4. Index here starts from 0. For example, if the + * tensor has shape 4x10, the axis for bounding boxes is likely to be 0. Negative axis is also + * supported: -1 gives the last axis and -2 gives the second, .etc. theFor shape 10x4, the + * axis is likely to be 1 (or -1, equivalently). + * @param type defines how values should be converted into boxes. See {@link Type} + * @param coordinateType defines how values are interpreted to coordinates. See {@link + * CoordinateType} + * @param height the height of the image which the boxes belong to. Only has effects when {@code + * coordinateType} is {@link CoordinateType#RATIO} + * @param width the width of the image which the boxes belong to. Only has effects when {@code + * coordinateType} is {@link CoordinateType#RATIO} + * @return A list of bounding boxes that the {@code tensor} represents. All dimensions except + * {@code boundingBoxAxis} will be collapsed with order kept. For example, given {@code + * tensor} with shape {1, 4, 10, 2} and {@code boundingBoxAxis = 1}, The result will be a list + * of 20 bounding boxes. 
+ * @throws IllegalArgumentException if size of bounding box dimension (set by {@code + * boundingBoxAxis}) is not 4. + * @throws IllegalArgumentException if {@code boundingBoxAxis} is not in {@code (-(D+1), D)} where + * {@code D} is the number of dimensions of the {@code tensor}. + * @throws IllegalArgumentException if {@code tensor} has data type other than {@link + * DataType#FLOAT32}. + */ + public static List<RectF> convert( + TensorBuffer tensor, + int[] valueIndex, + int boundingBoxAxis, + Type type, + CoordinateType coordinateType, + int height, + int width) { + int[] shape = tensor.getShape(); + checkArgument( + boundingBoxAxis >= -shape.length && boundingBoxAxis < shape.length, + String.format( + "Axis %d is not in range (-(D+1), D), where D is the number of dimensions of input" + + " tensor (shape=%s)", + boundingBoxAxis, Arrays.toString(shape))); + if (boundingBoxAxis < 0) { + boundingBoxAxis = shape.length + boundingBoxAxis; + } + checkArgument( + shape[boundingBoxAxis] == 4, + String.format( + "Size of bounding box dimension %d is not 4. Got %d in shape %s", + boundingBoxAxis, shape[boundingBoxAxis], Arrays.toString(shape))); + checkArgument( + valueIndex.length == 4, + String.format( + "Bounding box index array length %d is not 4. Got index array %s", + valueIndex.length, Arrays.toString(valueIndex))); + checkArgument( + tensor.getDataType() == DataType.FLOAT32, + "Bounding Boxes only create from FLOAT32 buffers. Got: " + tensor.getDataType().name()); + List<RectF> boundingBoxList = new ArrayList<>(); + // Collapse dimensions to {a, 4, b}. So each bounding box could be represent as (i, j), and its + // four values are (i, k, j), where 0 <= k < 4. We can compute the 4 flattened index by + // i * 4b + k * b + j. 
+ int a = 1; + for (int i = 0; i < boundingBoxAxis; i++) { + a *= shape[i]; + } + int b = 1; + for (int i = boundingBoxAxis + 1; i < shape.length; i++) { + b *= shape[i]; + } + float[] values = new float[4]; + ByteBuffer byteBuffer = tensor.getBuffer(); + byteBuffer.rewind(); + FloatBuffer floatBuffer = byteBuffer.asFloatBuffer(); + for (int i = 0; i < a; i++) { + for (int j = 0; j < b; j++) { + for (int k = 0; k < 4; k++) { + values[k] = floatBuffer.get((i * 4 + k) * b + j); + } + boundingBoxList.add( + convertOneBoundingBox(values, valueIndex, type, coordinateType, height, width)); + } + } + byteBuffer.rewind(); + return boundingBoxList; + } + + private static RectF convertOneBoundingBox( + float[] values, + int[] valueIndex, + Type type, + CoordinateType coordinateType, + int height, + int width) { + float[] orderedValues = new float[4]; + for (int i = 0; i < 4; i++) { + orderedValues[i] = values[valueIndex[i]]; + } + return convertOneBoundingBox(orderedValues, type, coordinateType, height, width); + } + + private static RectF convertOneBoundingBox( + float[] values, Type type, CoordinateType coordinateType, int height, int width) { + switch (type) { + case BOUNDARIES: + return convertFromBoundaries(values, coordinateType, height, width); + case UPPER_LEFT: + return convertFromUpperLeft(values, coordinateType, height, width); + case CENTER: + return convertFromCenter(values, coordinateType, height, width); + } + throw new IllegalArgumentException("Cannot recognize BoundingBox.Type " + type); + } + + private static RectF convertFromBoundaries( + float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) { + float left = values[0]; + float top = values[1]; + float right = values[2]; + float bottom = values[3]; + return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType); + } + + private static RectF convertFromUpperLeft( + float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) { + float left = 
values[0]; + float top = values[1]; + float right = values[0] + values[2]; + float bottom = values[1] + values[3]; + return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType); + } + + private static RectF convertFromCenter( + float[] values, CoordinateType coordinateType, int imageHeight, int imageWidth) { + float centerX = values[0]; + float centerY = values[1]; + float w = values[2]; + float h = values[3]; + + float left = centerX - w / 2; + float top = centerY - h / 2; + float right = centerX + w / 2; + float bottom = centerY + h / 2; + return getRectF(left, top, right, bottom, imageHeight, imageWidth, coordinateType); + } + + private static RectF getRectF( + float left, + float top, + float right, + float bottom, + int imageHeight, + int imageWidth, + CoordinateType coordinateType) { + if (coordinateType == CoordinateType.PIXEL) { + return new RectF( + left, top, right, bottom); + } else if (coordinateType == CoordinateType.RATIO) { + return new RectF( + left * imageWidth, top * imageHeight, right * imageWidth, bottom * imageHeight); + } else { + throw new IllegalArgumentException("Cannot convert coordinate type " + coordinateType); + } + } + + // Private constructor to prevent initialization. + private BoundingBoxUtil() {} +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java new file mode 100644 index 00000000..e92d0959 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ColorSpaceType.java @@ -0,0 +1,212 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import java.util.Arrays; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Represents the type of color space of an image. */ +public enum ColorSpaceType { + /** Each pixel has red, green, and blue color components. */ + RGB { + + // The channel axis should always be 3 for RGB images. + private static final int CHANNEL_VALUE = 3; + + @Override + Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) { + return ImageConversions.convertRgbTensorBufferToBitmap(buffer); + } + + @Override + int getChannelValue() { + return CHANNEL_VALUE; + } + + @Override + int[] getNormalizedShape(int[] shape) { + switch (shape.length) { + // The shape is in (h, w, c) format. + case 3: + return insertValue(shape, BATCH_DIM, BATCH_VALUE); + case 4: + return shape; + default: + throw new IllegalArgumentException( + getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape)); + } + } + + @Override + String getShapeInfoMessage() { + return "The shape of a RGB image should be (h, w, c) or (1, h, w, c), and channels" + + " representing R, G, B in order. "; + } + + @Override + Config toBitmapConfig() { + return Config.ARGB_8888; + } + }, + + /** Each pixel is a single element representing only the amount of light. 
*/ + GRAYSCALE { + + // The channel axis should always be 1 for grayscale images. + private static final int CHANNEL_VALUE = 1; + + @Override + Bitmap convertTensorBufferToBitmap(TensorBuffer buffer) { + return ImageConversions.convertGrayscaleTensorBufferToBitmap(buffer); + } + + @Override + int getChannelValue() { + return CHANNEL_VALUE; + } + + @Override + int[] getNormalizedShape(int[] shape) { + switch (shape.length) { + // The shape is in (h, w) format. + case 2: + int[] shapeWithBatch = insertValue(shape, BATCH_DIM, BATCH_VALUE); + return insertValue(shapeWithBatch, CHANNEL_DIM, CHANNEL_VALUE); + case 4: + return shape; + default: + // (1, h, w) and (h, w, 1) are potential grayscale image shapes. However, since they + // both have three dimensions, it will require extra info to differentiate between them. + // Since we haven't encountered real use cases of these two shapes, they are not supported + // at this moment to avoid confusion. We may want to revisit it in the future. + throw new IllegalArgumentException( + getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape)); + } + } + + @Override + String getShapeInfoMessage() { + return "The shape of a grayscale image should be (h, w) or (1, h, w, 1). "; + } + + @Override + Config toBitmapConfig() { + return Config.ALPHA_8; + } + }; + + private static final int BATCH_DIM = 0; // The first element of the normalizaed shape. + private static final int BATCH_VALUE = 1; // The batch axis should always be one. + private static final int HEIGHT_DIM = 1; // The second element of the normalizaed shape. + private static final int WIDTH_DIM = 2; // The third element of the normalizaed shape. + private static final int CHANNEL_DIM = 3; // The fourth element of the normalizaed shape. + + /** + * Converts a bitmap configuration into the corresponding color space type. 
+ * + * @throws IllegalArgumentException if the config is unsupported + */ + static ColorSpaceType fromBitmapConfig(Config config) { + switch (config) { + case ARGB_8888: + return ColorSpaceType.RGB; + case ALPHA_8: + return ColorSpaceType.GRAYSCALE; + default: + throw new IllegalArgumentException( + "Bitmap configuration: " + config + ", is not supported yet."); + } + } + + /** + * Verifies if the given shape matches the color space type. + * + * @throws IllegalArgumentException if {@code shape} does not match the color space type + */ + void assertShape(int[] shape) { + int[] normalizedShape = getNormalizedShape(shape); + checkArgument( + isValidNormalizedShape(normalizedShape), + getShapeInfoMessage() + "The provided image shape is " + Arrays.toString(shape)); + } + + /** + * Converts a {@link TensorBuffer} that represents an image to a Bitmap with the color space type. + * + * @throws IllegalArgumentException if the shape of buffer does not match the color space type + */ + abstract Bitmap convertTensorBufferToBitmap(TensorBuffer buffer); + + /** + * Returns the width of the given shape corresponding to the color space type. + * + * @throws IllegalArgumentException if {@code shape} does not match the color space type + */ + int getWidth(int[] shape) { + assertShape(shape); + return getNormalizedShape(shape)[WIDTH_DIM]; + } + + /** + * Returns the height of the given shape corresponding to the color space type. + * + * @throws IllegalArgumentException if {@code shape} does not match the color space type + */ + int getHeight(int[] shape) { + assertShape(shape); + return getNormalizedShape(shape)[HEIGHT_DIM]; + } + + abstract int getChannelValue(); + + /** + * Gets the normalized shape in the form of (1, h, w, c). Sometimes, a given shape may not have + * batch or channel axis. + */ + abstract int[] getNormalizedShape(int[] shape); + + abstract String getShapeInfoMessage(); + + /** Converts the color space type to the corresponding bitmap config. 
*/ + abstract Config toBitmapConfig(); + + /** Inserts a value at the specified position and return the new array. */ + private static int[] insertValue(int[] array, int pos, int value) { + int[] newArray = new int[array.length + 1]; + for (int i = 0; i < pos; i++) { + newArray[i] = array[i]; + } + newArray[pos] = value; + for (int i = pos + 1; i < newArray.length; i++) { + newArray[i] = array[i - 1]; + } + return newArray; + } + + protected boolean isValidNormalizedShape(int[] shape) { + if (shape[BATCH_DIM] == BATCH_VALUE + && shape[HEIGHT_DIM] > 0 + && shape[WIDTH_DIM] > 0 + && shape[CHANNEL_DIM] == getChannelValue()) { + return true; + } + return false; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java new file mode 100644 index 00000000..1a145de7 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageContainer.java @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import android.graphics.Bitmap; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Handles image conversion across different image types. 
+ * + * <p>An {@link ImageContainer} should support the conversion between the underlying image format to + * the following image types: + * + * <ul> + * <li>{@link Bitmap} + * <li>{@link TensorBuffer} of the specified data type. + * </ul> + */ +interface ImageContainer { + + /** Performs deep copy of the {@link ImageContainer}. */ + ImageContainer clone(); + + /** Returns the width of the image. */ + int getWidth(); + + /** Returns the height of the image. */ + int getHeight(); + + /** Gets the {@link Bitmap} representation of the underlying image format. */ + Bitmap getBitmap(); + + /** + * Gets the {@link TensorBuffer} representation with the specific {@code dataType} of the + * underlying image format. + */ + TensorBuffer getTensorBuffer(DataType dataType); + + /** Returns the color space type of the image. */ + ColorSpaceType getColorSpaceType(); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java new file mode 100644 index 00000000..d6e567a2 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageConversions.java @@ -0,0 +1,146 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import android.graphics.Bitmap; +import android.graphics.Color; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Implements some stateless image conversion methods. + * + * <p>This class is an internal helper for {@link org.tensorflow.lite.support.image}. + */ +class ImageConversions { + + /** + * Converts a {@link TensorBuffer} that represents a RGB image to an ARGB_8888 Bitmap. + * + * <p>Data in buffer will be converted into integer to match the Bitmap API. + * + * @param buffer a RGB image. Its shape should be either (h, w, 3) or (1, h, w, 3) + * @throws IllegalArgumentException if the shape of buffer is neither (h, w, 3) nor (1, h, w, 3) + */ + static Bitmap convertRgbTensorBufferToBitmap(TensorBuffer buffer) { + int[] shape = buffer.getShape(); + ColorSpaceType rgb = ColorSpaceType.RGB; + rgb.assertShape(shape); + + int h = rgb.getHeight(shape); + int w = rgb.getWidth(shape); + Bitmap bitmap = Bitmap.createBitmap(w, h, rgb.toBitmapConfig()); + + // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time. + int[] intValues = new int[w * h]; + int[] rgbValues = buffer.getIntArray(); + for (int i = 0, j = 0; i < intValues.length; i++) { + int r = rgbValues[j++]; + int g = rgbValues[j++]; + int b = rgbValues[j++]; + intValues[i] = Color.rgb(r, g, b); + } + bitmap.setPixels(intValues, 0, w, 0, 0, w, h); + + return bitmap; + } + + /** + * Converts a {@link TensorBuffer} that represents a grayscale image to an ALPHA_8 Bitmap. + * + * <p>Data in buffer will be converted into integer to match the Bitmap API. + * + * @param buffer a grayscale image. 
Its shape should be either (h, w) or (1, h, w) + * @throws IllegalArgumentException if the shape of buffer is neither (h, w) nor (1, h, w, 1) + */ + static Bitmap convertGrayscaleTensorBufferToBitmap(TensorBuffer buffer) { + // Convert buffer into Uint8 as needed. + TensorBuffer uint8Buffer = + buffer.getDataType() == DataType.UINT8 + ? buffer + : TensorBuffer.createFrom(buffer, DataType.UINT8); + + int[] shape = uint8Buffer.getShape(); + ColorSpaceType grayscale = ColorSpaceType.GRAYSCALE; + grayscale.assertShape(shape); + + // Even though `Bitmap.createBitmap(int[] colors, int width, int height, Bitmap.Config config)` + // seems to work for internal Android testing framework, but it actually doesn't work for the + // real Android environment. + // + // The only reliable way to create an ALPHA_8 Bitmap is to use `copyPixelsFromBuffer()` to load + // the pixels from a ByteBuffer, and then use `copyPixelsToBuffer` to read out. + // Note: for ALPHA_8 Bitmap, methods such as, `setPixels()` and `getPixels()` do not work. + Bitmap bitmap = + Bitmap.createBitmap( + grayscale.getWidth(shape), grayscale.getHeight(shape), grayscale.toBitmapConfig()); + uint8Buffer.getBuffer().rewind(); + bitmap.copyPixelsFromBuffer(uint8Buffer.getBuffer()); + return bitmap; + } + + /** + * Converts an Image in a Bitmap to a TensorBuffer (3D Tensor: Width-Height-Channel) whose memory + * is already allocated, or could be dynamically allocated. + * + * @param bitmap The Bitmap object representing the image. Currently we only support ARGB_8888 + * config. + * @param buffer The destination of the conversion. Needs to be created in advance. If it's + * fixed-size, its flat size should be w*h*3. + * @throws IllegalArgumentException if the buffer is fixed-size, but the size doesn't match. 
+ */ + static void convertBitmapToTensorBuffer(Bitmap bitmap, TensorBuffer buffer) { + int w = bitmap.getWidth(); + int h = bitmap.getHeight(); + int[] intValues = new int[w * h]; + bitmap.getPixels(intValues, 0, w, 0, 0, w, h); + // TODO(b/138904567): Find a way to avoid creating multiple intermediate buffers every time. + int flatSize = w * h * 3; + int[] shape = new int[] {h, w, 3}; + switch (buffer.getDataType()) { + case UINT8: + byte[] byteArr = new byte[w * h * 3]; + for (int i = 0, j = 0; i < intValues.length; i++) { + byteArr[j++] = (byte) ((intValues[i] >> 16) & 0xff); + byteArr[j++] = (byte) ((intValues[i] >> 8) & 0xff); + byteArr[j++] = (byte) (intValues[i] & 0xff); + } + ByteBuffer byteBuffer = ByteBuffer.allocateDirect(flatSize); + byteBuffer.order(ByteOrder.nativeOrder()); + byteBuffer.put(byteArr); + buffer.loadBuffer(byteBuffer, shape); + break; + case FLOAT32: + float[] floatArr = new float[w * h * 3]; + for (int i = 0, j = 0; i < intValues.length; i++) { + floatArr[j++] = (float) ((intValues[i] >> 16) & 0xff); + floatArr[j++] = (float) ((intValues[i] >> 8) & 0xff); + floatArr[j++] = (float) (intValues[i] & 0xff); + } + buffer.loadArray(floatArr, shape); + break; + default: + // Should never happen. + throw new IllegalStateException( + "The type of TensorBuffer, " + buffer.getBuffer() + ", is unsupported."); + } + } + + // Hide the constructor as the class is static. + private ImageConversions() {} +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java new file mode 100644 index 00000000..1e546634 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageOperator.java @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import android.graphics.PointF; +import org.tensorflow.lite.support.common.Operator; + +/** Operates a TensorImage object. Used in ImageProcessor. */ +public interface ImageOperator extends Operator<TensorImage> { + /** @see org.tensorflow.lite.support.common.Operator#apply(java.lang.Object) */ + @Override + TensorImage apply(TensorImage image); + + /** Computes the width of the expected output image when input image size is given. */ + int getOutputImageWidth(int inputImageHeight, int inputImageWidth); + + /** Computes the height of the expected output image when input image size is given. */ + int getOutputImageHeight(int inputImageHeight, int inputImageWidth); + + /** + * Transforms a point from coordinates system of the result image back to the one of the input + * image. + * + * @param point the point from the result coordinates system. + * @param inputImageHeight the height of input image. + * @param inputImageWidth the width of input image. + * @return the point with the coordinates from the coordinates system of the input image. 
+ */ + PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java new file mode 100644 index 00000000..e1ef1309 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ImageProcessor.java @@ -0,0 +1,198 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import android.graphics.PointF; +import android.graphics.RectF; +import java.util.ArrayList; +import java.util.List; +import java.util.ListIterator; +import org.tensorflow.lite.support.common.Operator; +import org.tensorflow.lite.support.common.SequentialProcessor; +import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.image.ops.Rot90Op; +import org.tensorflow.lite.support.image.ops.TensorOperatorWrapper; + +/** + * ImageProcessor is a helper class for preprocessing and postprocessing {@link TensorImage}. It + * could transform a {@link TensorImage} to another by executing a chain of {@link ImageOperator}. 
+ * + * <p>Example Usage: + * + * <pre> + * ImageProcessor processor = new ImageProcessor.Builder() + * .add(new ResizeOp(224, 224, ResizeMethod.NEAREST_NEIGHBOR) + * .add(new Rot90Op()) + * .add(new NormalizeOp(127.5f, 127.5f)) + * .build(); + * TensorImage anotherTensorImage = processor.process(tensorImage); + * </pre> + * + * <p><b>WARNING:</b> Instances of an {@code ImageProcessor} are <b>not</b> thread-safe with {@link + * #updateNumberOfRotations}. Updating the number of rotations and then processing images (using + * {@link #process}) must be protected from concurrent access. It is recommended to create separate + * {@code ImageProcessor} instances for each thread. If multiple threads access a {@code + * ImageProcessor} concurrently, it must be synchronized externally. + * + * @see ImageProcessor.Builder to build a {@link ImageProcessor} instance + * @see ImageProcessor#process(TensorImage) to apply the processor on a {@link TensorImage} + */ +public class ImageProcessor extends SequentialProcessor<TensorImage> { + private ImageProcessor(Builder builder) { + super(builder); + } + + /** + * Transforms a point from coordinates system of the result image back to the one of the input + * image. + * + * @param point the point from the result coordinates system. + * @param inputImageHeight the height of input image. + * @param inputImageWidth the width of input image. + * @return the point with the coordinates from the coordinates system of the input image. 
+ */ + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { + List<Integer> widths = new ArrayList<>(); + List<Integer> heights = new ArrayList<>(); + int currentWidth = inputImageWidth; + int currentHeight = inputImageHeight; + for (Operator<TensorImage> op : operatorList) { + widths.add(currentWidth); + heights.add(currentHeight); + ImageOperator imageOperator = (ImageOperator) op; + int newHeight = imageOperator.getOutputImageHeight(currentHeight, currentWidth); + int newWidth = imageOperator.getOutputImageWidth(currentHeight, currentWidth); + currentHeight = newHeight; + currentWidth = newWidth; + } + ListIterator<Operator<TensorImage>> opIterator = operatorList.listIterator(operatorList.size()); + ListIterator<Integer> widthIterator = widths.listIterator(widths.size()); + ListIterator<Integer> heightIterator = heights.listIterator(heights.size()); + while (opIterator.hasPrevious()) { + ImageOperator imageOperator = (ImageOperator) opIterator.previous(); + int height = heightIterator.previous(); + int width = widthIterator.previous(); + point = imageOperator.inverseTransform(point, height, width); + } + return point; + } + + /** + * Transforms a rectangle from coordinates system of the result image back to the one of the input + * image. + * + * @param rect the rectangle from the result coordinates system. + * @param inputImageHeight the height of input image. + * @param inputImageWidth the width of input image. + * @return the rectangle with the coordinates from the coordinates system of the input image. 
+ */ + public RectF inverseTransform(RectF rect, int inputImageHeight, int inputImageWidth) { + // when rotation is involved, corner order may change - top left changes to bottom right, .etc + PointF p1 = + inverseTransform(new PointF(rect.left, rect.top), inputImageHeight, inputImageWidth); + PointF p2 = + inverseTransform(new PointF(rect.right, rect.bottom), inputImageHeight, inputImageWidth); + return new RectF( + Math.min(p1.x, p2.x), Math.min(p1.y, p2.y), Math.max(p1.x, p2.x), Math.max(p1.y, p2.y)); + } + + /** + * The Builder to create an ImageProcessor, which could be executed later. + * + * @see #add(TensorOperator) to add a general TensorOperator + * @see #add(ImageOperator) to add an ImageOperator + * @see #build() complete the building process and get a built Processor + */ + public static class Builder extends SequentialProcessor.Builder<TensorImage> { + public Builder() { + super(); + } + + /** + * Adds an {@link ImageOperator} into the Operator chain. + * + * @param op the Operator instance to be executed then + */ + public Builder add(ImageOperator op) { + super.add(op); + return this; + } + + /** + * Adds a {@link TensorOperator} into the Operator chain. In execution, the processor calls + * {@link TensorImage#getTensorBuffer()} to transform the {@link TensorImage} by transforming + * the underlying {@link org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. + * + * @param op the Operator instance to be executed then + */ + public Builder add(TensorOperator op) { + return add(new TensorOperatorWrapper(op)); + } + + /** Completes the building process and gets the {@link ImageProcessor} instance. */ + @Override + public ImageProcessor build() { + return new ImageProcessor(this); + } + } + + /** + * Updates the number of rotations for the first {@link Rot90Op} in this {@link ImageProcessor}. + * + * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. 
Updating the number of rotations and + * then processing images (using {@link #process}) must be protected from concurrent access with + * additional synchronization. + * + * @param k the number of rotations + * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link + * ImageProcessor} + */ + public void updateNumberOfRotations(int k) { + updateNumberOfRotations(k, /*occurrence=*/ 0); + } + + /** + * Updates the number of rotations for the {@link Rot90Op} specified by {@code occurrence} in this + * {@link ImageProcessor}. + * + * <p><b>WARNING:</b>this method is <b>not</b> thread-safe. Updating the number of rotations and + * then processing images (using {@link #process}) must be protected from concurrent access with + * additional synchronization. + * + * @param k the number of rotations + * @param occurrence the index of perticular {@link Rot90Op} in this {@link ImageProcessor}. For + * example, if the second {@link Rot90Op} needs to be updated, {@code occurrence} should be + * set to 1. + * @throws IndexOutOfBoundsException if {@code occurrence} is negative or is not less than the + * number of {@link Rot90Op} in this {@link ImageProcessor} + * @throws IllegalStateException if {@link Rot90Op} has not been added to this {@link + * ImageProcessor} + */ + public synchronized void updateNumberOfRotations(int k, int occurrence) { + SupportPreconditions.checkState( + operatorIndex.containsKey(Rot90Op.class.getName()), + "The Rot90Op has not been added to the ImageProcessor."); + + List<Integer> indexes = operatorIndex.get(Rot90Op.class.getName()); + SupportPreconditions.checkElementIndex(occurrence, indexes.size(), "occurrence"); + + // The index of the Rot90Op to be replaced in operatorList. 
+ int index = indexes.get(occurrence); + Rot90Op newRot = new Rot90Op(k); + operatorList.set(index, newRot); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java new file mode 100644 index 00000000..d047a8e0 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorBufferContainer.java @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import android.graphics.Bitmap; +import android.util.Log; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Holds a {@link TensorBuffer} and converts it to other image formats as needed. */ +final class TensorBufferContainer implements ImageContainer { + + private final TensorBuffer buffer; + private final ColorSpaceType colorSpaceType; + private static final String TAG = TensorBufferContainer.class.getSimpleName(); + + /** + * Creates a {@link TensorBufferContainer} object with the specified {@link + * TensorImage#ColorSpaceType}. 
+ * + * @throws IllegalArgumentException if the shape of the {@link TensorBuffer} does not match the + * specified color space type + */ + static TensorBufferContainer create(TensorBuffer buffer, ColorSpaceType colorSpaceType) { + return new TensorBufferContainer(buffer, colorSpaceType); + } + + private TensorBufferContainer(TensorBuffer buffer, ColorSpaceType colorSpaceType) { + colorSpaceType.assertShape(buffer.getShape()); + this.buffer = buffer; + this.colorSpaceType = colorSpaceType; + } + + @Override + public TensorBufferContainer clone() { + return create(TensorBuffer.createFrom(buffer, buffer.getDataType()), colorSpaceType); + } + + @Override + public Bitmap getBitmap() { + if (buffer.getDataType() != DataType.UINT8) { + // Print warning instead of throwing an exception. When using float models, users may want to + // convert the resulting float image into Bitmap. That's fine to do so, as long as they are + // aware of the potential accuracy lost when casting to uint8. + Log.w( + TAG, + "<Warning> TensorBufferContainer is holding a non-uint8 image. The conversion to Bitmap" + + " will cause numeric casting and clamping on the data value."); + } + + return colorSpaceType.convertTensorBufferToBitmap(buffer); + } + + @Override + public TensorBuffer getTensorBuffer(DataType dataType) { + // If the data type of buffer is desired, return it directly. Not making a defensive copy for + // performance considerations. During image processing, users may need to set and get the + // TensorBuffer many times. + // Otherwise, create another one with the expected data type. + return buffer.getDataType() == dataType ? 
buffer : TensorBuffer.createFrom(buffer, dataType); + } + + @Override + public int getWidth() { + return colorSpaceType.getWidth(buffer.getShape()); + } + + @Override + public int getHeight() { + return colorSpaceType.getHeight(buffer.getShape()); + } + + @Override + public ColorSpaceType getColorSpaceType() { + return colorSpaceType; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java new file mode 100644 index 00000000..96cae716 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/TensorImage.java @@ -0,0 +1,312 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image; + +import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; + +import android.graphics.Bitmap; +import java.nio.ByteBuffer; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * TensorImage is the wrapper class for Image object. When using image processing utils in + * TFLite.support library, it's common to convert image objects in variant types to TensorImage at + * first. + * + * <p>At present, only RGB images are supported, and the A channel is always ignored. 
+ * + * <p>Details of data storage: a {@link TensorImage} object may have 2 potential sources of truth: a + * {@link Bitmap} or a {@link TensorBuffer}. {@link TensorImage} maintains the state and only + * converts one to the other when needed. A typical use case of {@link TensorImage} is to first load + * a {@link Bitmap} image, then process it using {@link ImageProcessor}, and finally get the + * underlying {@link ByteBuffer} of the {@link TensorBuffer} and feed it into the TFLite + * interpreter. + * + * <p>IMPORTANT: to achieve the best performance, {@link TensorImage} avoids copying data whenever + * it's possible. Therefore, it doesn't own its data. Callers should not modify data objects those + * are passed to {@link TensorImage#load(Bitmap)} or {@link TensorImage#load(TensorBuffer)}. + * + * <p>IMPORTANT: all methods are not proved thread-safe. + * + * @see ImageProcessor which is often used for transforming a {@link TensorImage}. + */ +// TODO(b/138907116): Support loading images from TensorBuffer with properties. +// TODO(b/138905544): Support directly loading RGBBytes, YUVBytes and other types if necessary. +public class TensorImage { + + private final DataType dataType; + private ImageContainer container = null; + + /** + * Initializes a {@link TensorImage} object. + * + * <p>Note: the data type of this {@link TensorImage} is {@link DataType#UINT8}. Use {@link + * #TensorImage(DataType)} if other data types are preferred. + */ + public TensorImage() { + this(DataType.UINT8); + } + + /** + * Initializes a {@link TensorImage} object with the specified data type. + * + * <p>When getting a {@link TensorBuffer} or a {@link ByteBuffer} from this {@link TensorImage}, + * such as using {@link #getTensorBuffer} and {@link #getBuffer}, the data values will be + * converted to the specified data type. + * + * <p>Note: the shape of a {@link TensorImage} is not fixed. It can be adjusted to the shape of + * the image being loaded to this {@link TensorImage}. 
+ * + * @param dataType the expected data type of the resulting {@link TensorBuffer}. The type is + * always fixed during the lifetime of the {@link TensorImage}. To convert the data type, use + * {@link #createFrom(TensorImage, DataType)} to create a copy and convert data type at the + * same time. + * @throws IllegalArgumentException if {@code dataType} is neither {@link DataType#UINT8} nor + * {@link DataType#FLOAT32} + */ + public TensorImage(DataType dataType) { + checkArgument( + dataType == DataType.UINT8 || dataType == DataType.FLOAT32, + "Illegal data type for TensorImage: Only FLOAT32 and UINT8 are accepted"); + this.dataType = dataType; + } + + /** + * Initializes a {@link TensorImage} object of {@link DataType#UINT8} with a {@link Bitmap} . + * + * @see TensorImage#load(Bitmap) for reusing the object when it's expensive to create objects + * frequently, because every call of {@code fromBitmap} creates a new {@link TensorImage}. + */ + public static TensorImage fromBitmap(Bitmap bitmap) { + TensorImage image = new TensorImage(); + image.load(bitmap); + return image; + } + + /** + * Creates a deep-copy of a given {@link TensorImage} with the desired data type. + * + * @param src the {@link TensorImage} to copy from + * @param dataType the expected data type of newly created {@link TensorImage} + * @return a {@link TensorImage} whose data is copied from {@code src} and data type is {@code + * dataType} + */ + public static TensorImage createFrom(TensorImage src, DataType dataType) { + TensorImage dst = new TensorImage(dataType); + dst.container = src.container.clone(); + return dst; + } + + /** + * Loads a {@link Bitmap} image object into this {@link TensorImage}. + * + * <p>Note: if the {@link TensorImage} has data type other than {@link DataType#UINT8}, numeric + * casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link + * #getBuffer}, where the {@link Bitmap} will be converted into a {@link TensorBuffer}. 
+ * + * <p>Important: when loading a bitmap, DO NOT MODIFY the bitmap from the caller side anymore. The + * {@link TensorImage} object will rely on the bitmap. It will probably modify the bitmap as well. + * In this method, we perform a zero-copy approach for that bitmap, by simply holding its + * reference. Use {@code bitmap.copy(bitmap.getConfig(), true)} to create a copy if necessary. + * + * <p>Note: to get the best performance, please load images in the same shape to avoid memory + * re-allocation. + * + * @throws IllegalArgumentException if {@code bitmap} is not in ARGB_8888 + */ + public void load(Bitmap bitmap) { + container = BitmapContainer.create(bitmap); + } + + /** + * Loads a float array as RGB pixels into this {@link TensorImage}, representing the pixels + * inside. + * + * <p>Note: if the {@link TensorImage} has a data type other than {@link DataType#FLOAT32}, + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link + * #getBuffer}. + * + * @param pixels the RGB pixels representing the image + * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3) + * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) + */ + public void load(float[] pixels, int[] shape) { + TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); + buffer.loadArray(pixels, shape); + load(buffer); + } + + /** + * Loads an int array as RGB pixels into this {@link TensorImage}, representing the pixels inside. + * + * <p>Note: numeric casting and clamping will be applied to convert the values into the data type + * of this {@link TensorImage} when calling {@link #getTensorBuffer} and {@link #getBuffer}. 
+ * + * @param pixels the RGB pixels representing the image + * @param shape the shape of the image, should either in form (h, w, 3), or in form (1, h, w, 3) + * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) + */ + public void load(int[] pixels, int[] shape) { + TensorBuffer buffer = TensorBuffer.createDynamic(getDataType()); + buffer.loadArray(pixels, shape); + load(buffer); + } + + /** + * Loads a {@link TensorBuffer} containing pixel values. The color layout should be RGB. + * + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link + * #getBuffer}. + * + * @param buffer the {@link TensorBuffer} to be loaded. Its shape should be either (h, w, 3) or + * (1, h, w, 3) + * @throws IllegalArgumentException if the shape is neither (h, w, 3) nor (1, h, w, 3) + */ + public void load(TensorBuffer buffer) { + load(buffer, ColorSpaceType.RGB); + } + + /** + * Loads a {@link TensorBuffer} containing pixel values with the specific {@link ColorSapceType}. + * + * <p>Note: if the data type of {@code buffer} does not match that of this {@link TensorImage}, + * numeric casting and clamping will be applied when calling {@link #getTensorBuffer} and {@link + * #getBuffer}. + * + * @throws IllegalArgumentException if the shape of buffer does not match the color space type + * @see ColorSpaceType#assertShape + */ + public void load(TensorBuffer buffer, ColorSpaceType colorSpaceType) { + container = TensorBufferContainer.create(buffer, colorSpaceType); + } + + /** + * Returns a {@link Bitmap} representation of this {@link TensorImage}. + * + * <p>Numeric casting and clamping will be applied if the stored data is not uint8. + * + * <p>Note that, the reliable way to get pixels from an {@code ALPHA_8} Bitmap is to use {@code + * copyPixelsToBuffer}. Bitmap methods such as, `setPixels()` and `getPixels` do not work. 
+ * + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance + * concern, but if modification is necessary, please make a copy. + * + * @return a reference to a {@link Bitmap} in {@code ARGB_8888} config ("A" channel is always + * opaque) or in {@code ALPHA_8}, depending on the {@link ColorSpaceType} of this {@link + * TensorBuffer}. + * @throws IllegalStateException if the {@link TensorImage} never loads data + */ + public Bitmap getBitmap() { + if (container == null) { + throw new IllegalStateException("No image has been loaded yet."); + } + + return container.getBitmap(); + } + + /** + * Returns a {@link ByteBuffer} representation of this {@link TensorImage} with the expected data + * type. + * + * <p>Numeric casting and clamping will be applied if the stored data is different from the data + * type of the {@link TensorImage}. + * + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance + * concern, but if modification is necessary, please make a copy. + * + * <p>It's essentially a short cut for {@code getTensorBuffer().getBuffer()}. + * + * @return a reference to a {@link ByteBuffer} which holds the image data + * @throws IllegalStateException if the {@link TensorImage} never loads data + */ + public ByteBuffer getBuffer() { + return getTensorBuffer().getBuffer(); + } + + /** + * Returns a {@link TensorBuffer} representation of this {@link TensorImage} with the expected + * data type. + * + * <p>Numeric casting and clamping will be applied if the stored data is different from the data + * type of the {@link TensorImage}. + * + * <p>Important: it's only a reference. DO NOT MODIFY. We don't create a copy here for performance + * concern, but if modification is necessary, please make a copy. 
+ * + * @return a reference to a {@link TensorBuffer} which holds the image data + * @throws IllegalStateException if the {@link TensorImage} never loads data + */ + public TensorBuffer getTensorBuffer() { + if (container == null) { + throw new IllegalStateException("No image has been loaded yet."); + } + + return container.getTensorBuffer(dataType); + } + + /** + * Gets the data type of this {@link TensorImage}. + * + * @return a data type. Currently only {@link DataType#UINT8} and {@link DataType#FLOAT32} are + * supported. + */ + public DataType getDataType() { + return dataType; + } + + /** + * Gets the color space type of this {@link TensorImage}. + * + * @throws IllegalStateException if the {@link TensorImage} never loads data + */ + public ColorSpaceType getColorSpaceType() { + if (container == null) { + throw new IllegalStateException("No image has been loaded yet."); + } + + return container.getColorSpaceType(); + } + + /** + * Gets the image width. + * + * @throws IllegalStateException if the {@link TensorImage} never loads data + * @throws IllegalArgumentException if the underlying data is corrupted + */ + public int getWidth() { + if (container == null) { + throw new IllegalStateException("No image has been loaded yet."); + } + + return container.getWidth(); + } + + /** + * Gets the image height. 
+ * + * @throws IllegalStateException if the {@link TensorImage} never loads data + * @throws IllegalArgumentException if the underlying data is corrupted + */ + public int getHeight() { + if (container == null) { + throw new IllegalStateException("No image has been loaded yet."); + } + + return container.getHeight(); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java new file mode 100644 index 00000000..35606dd6 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeOp.java @@ -0,0 +1,89 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image.ops; + +import android.graphics.Bitmap; +import android.graphics.PointF; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.image.ImageOperator; +import org.tensorflow.lite.support.image.TensorImage; + +/** + * As a computation unit for processing images, it can resize an image to user-specified size. + * + * <p>It interpolates pixels when image is stretched, and discards pixels when image is compressed. + * + * @see ResizeWithCropOrPadOp for resizing without content distortion. 
+ */ +public class ResizeOp implements ImageOperator { + + /** Algorithms for resizing. */ + public enum ResizeMethod { + BILINEAR, + NEAREST_NEIGHBOR + } + + private final int targetHeight; + private final int targetWidth; + private final boolean useBilinear; + + /** + * Creates a ResizeOp which can resize images to specified size in specified method. + * + * @param targetHeight: The expected height of resized image. + * @param targetWidth: The expected width of resized image. + * @param resizeMethod: The algorithm to use for resizing. Options: {@link ResizeMethod} + */ + public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod) { + this.targetHeight = targetHeight; + this.targetWidth = targetWidth; + useBilinear = (resizeMethod == ResizeMethod.BILINEAR); + } + + /** + * Applies the defined resizing on given image and returns the result. + * + * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance + * with the output. + * + * @param image input image. + * @return output image. 
+ */ + @Override + @NonNull + public TensorImage apply(@NonNull TensorImage image) { + Bitmap scaled = + Bitmap.createScaledBitmap(image.getBitmap(), targetWidth, targetHeight, useBilinear); + image.load(scaled); + return image; + } + + @Override + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { + return targetHeight; + } + + @Override + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { + return targetWidth; + } + + @Override + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { + return new PointF( + point.x * inputImageWidth / targetWidth, point.y * inputImageHeight / targetHeight); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java new file mode 100644 index 00000000..404429ef --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/ResizeWithCropOrPadOp.java @@ -0,0 +1,125 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.image.ops; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Canvas; +import android.graphics.PointF; +import android.graphics.Rect; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.image.ImageOperator; +import org.tensorflow.lite.support.image.TensorImage; + +/** + * As a computation unit for processing images, it could resize image to predefined size. + * + * <p>It will not stretch or compress the content of image. However, to fit the new size, it crops + * or pads pixels. When it crops image, it performs a center-crop; when it pads pixels, it performs + * a zero-padding. + * + * @see ResizeOp for reszing images while stretching / compressing the content. + */ +public class ResizeWithCropOrPadOp implements ImageOperator { + private final int targetHeight; + private final int targetWidth; + private final Bitmap output; + + /** + * Creates a ResizeWithCropOrPadOp which could crop/pad images to specified size. It adopts + * center-crop and zero-padding. + * + * @param targetHeight: The expected height of cropped/padded image. + * @param targetWidth: The expected width of cropped/padded image. + */ + public ResizeWithCropOrPadOp(int targetHeight, int targetWidth) { + this.targetHeight = targetHeight; + this.targetWidth = targetWidth; + output = Bitmap.createBitmap(this.targetWidth, this.targetHeight, Config.ARGB_8888); + } + + /** + * Applies the defined resizing with cropping or/and padding on given image and returns the + * result. + * + * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance + * with the output. + * + * @param image input image. + * @return output image. 
+ */ + @Override + @NonNull + public TensorImage apply(@NonNull TensorImage image) { + Bitmap input = image.getBitmap(); + int srcL; + int srcR; + int srcT; + int srcB; + int dstL; + int dstR; + int dstT; + int dstB; + int w = input.getWidth(); + int h = input.getHeight(); + if (targetWidth > w) { // padding + srcL = 0; + srcR = w; + dstL = (targetWidth - w) / 2; + dstR = dstL + w; + } else { // cropping + dstL = 0; + dstR = targetWidth; + srcL = (w - targetWidth) / 2; + srcR = srcL + targetWidth; + } + if (targetHeight > h) { // padding + srcT = 0; + srcB = h; + dstT = (targetHeight - h) / 2; + dstB = dstT + h; + } else { // cropping + dstT = 0; + dstB = targetHeight; + srcT = (h - targetHeight) / 2; + srcB = srcT + targetHeight; + } + Rect src = new Rect(srcL, srcT, srcR, srcB); + Rect dst = new Rect(dstL, dstT, dstR, dstB); + new Canvas(output).drawBitmap(input, src, dst, null); + image.load(output); + return image; + } + + @Override + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { + return targetHeight; + } + + @Override + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { + return targetWidth; + } + + @Override + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { + return transformImpl(point, targetHeight, targetWidth, inputImageHeight, inputImageWidth); + } + + private static PointF transformImpl(PointF point, int srcH, int srcW, int dstH, int dstW) { + return new PointF(point.x + (dstW - srcW) / 2, point.y + (dstH - srcH) / 2); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java new file mode 100644 index 00000000..2fa22937 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/Rot90Op.java @@ -0,0 +1,103 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image.ops; + +import android.graphics.Bitmap; +import android.graphics.Matrix; +import android.graphics.PointF; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.image.ImageOperator; +import org.tensorflow.lite.support.image.TensorImage; + +/** Rotates image counter-clockwise. */ +public class Rot90Op implements ImageOperator { + + private final int numRotation; + + /** Creates a Rot90 Op which will rotate image by 90 degree counter-clockwise. */ + public Rot90Op() { + this(1); + } + + /** + * Creates a Rot90 Op which will rotate image by 90 degree for {@code k} times counter-clockwise. + * + * @param k: The number of times the image is rotated by 90 degrees. If it's positive, the image + * will be rotated counter-clockwise. If it's negative, the op will rotate image clockwise. + */ + public Rot90Op(int k) { + numRotation = k % 4; + } + + /** + * Applies the defined rotation on given image and returns the result. + * + * <p>Note: the content of input {@code image} will change, and {@code image} is the same instance + * with the output. + * + * @param image input image. + * @return output image. 
+ */ + @NonNull + @Override + public TensorImage apply(@NonNull TensorImage image) { + Bitmap input = image.getBitmap(); + if (numRotation == 0) { + return image; + } + int w = input.getWidth(); + int h = input.getHeight(); + Matrix matrix = new Matrix(); + matrix.postTranslate(w * 0.5f, h * 0.5f); + matrix.postRotate(-90 * numRotation); + int newW = (numRotation % 2 == 0) ? w : h; + int newH = (numRotation % 2 == 0) ? h : w; + matrix.postTranslate(newW * 0.5f, newH * 0.5f); + Bitmap output = Bitmap.createBitmap(input, 0, 0, w, h, matrix, false); + image.load(output); + return image; + } + + @Override + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { + return (numRotation % 2 == 0) ? inputImageHeight : inputImageWidth; + } + + @Override + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { + return (numRotation % 2 == 0) ? inputImageWidth : inputImageHeight; + } + + @Override + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { + int inverseNumRotation = (4 - numRotation) % 4; + int height = getOutputImageHeight(inputImageHeight, inputImageWidth); + int width = getOutputImageWidth(inputImageHeight, inputImageWidth); + return transformImpl(point, height, width, inverseNumRotation); + } + + private static PointF transformImpl(PointF point, int height, int width, int numRotation) { + if (numRotation == 0) { + return point; + } else if (numRotation == 1) { + return new PointF(point.y, width - point.x); + } else if (numRotation == 2) { + return new PointF(width - point.x, height - point.y); + } else { // numRotation == 3 + return new PointF(height - point.y, point.x); + } + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java new file mode 100644 index 00000000..420018dd --- /dev/null +++ 
b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/image/ops/TensorOperatorWrapper.java @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.image.ops; + +import android.graphics.PointF; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.common.TensorOperator; +import org.tensorflow.lite.support.image.ImageOperator; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * The adapter that makes a TensorOperator able to run with TensorImage. + * + * @see org.tensorflow.lite.support.common.TensorOperator + * @see org.tensorflow.lite.support.image.TensorImage + */ +public class TensorOperatorWrapper implements ImageOperator { + + private final TensorOperator tensorOp; + + /** + * Wraps a {@link TensorOperator} object as an {@link ImageOperator}, so that the {@link + * TensorOperator} could handle {@link TensorImage} objects by handling its underlying {@link + * org.tensorflow.lite.support.tensorbuffer.TensorBuffer}. + * + * <p>Requirement: The {@code op} should not change coordinate system when applied on an image. + * + * @param op The created operator. 
+ */ + public TensorOperatorWrapper(TensorOperator op) { + tensorOp = op; + } + + @Override + @NonNull + public TensorImage apply(@NonNull TensorImage image) { + SupportPreconditions.checkNotNull(image, "Op cannot apply on null image."); + TensorBuffer resBuffer = tensorOp.apply(image.getTensorBuffer()); + // Some ops may change the data type of the underlying TensorBuffer, such as CastOp. Therefore, + // need to create a new TensorImage with the correct data type. + TensorImage resImage = new TensorImage(resBuffer.getDataType()); + resImage.load(resBuffer); + return resImage; + } + + @Override + public int getOutputImageHeight(int inputImageHeight, int inputImageWidth) { + return inputImageHeight; + } + + @Override + public int getOutputImageWidth(int inputImageHeight, int inputImageWidth) { + return inputImageWidth; + } + + @Override + public PointF inverseTransform(PointF point, int inputImageHeight, int inputImageWidth) { + return point; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java new file mode 100644 index 00000000..5b043a9f --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.label; + +import java.util.Objects; +import org.tensorflow.lite.annotations.UsedByReflection; + +/** + * Category is a util class, contains a label, its display name and a float value as score. + * Typically it's used as result of classification tasks. + */ +@UsedByReflection("TFLiteSupport/Task") +public final class Category { + private final String label; + private final String displayName; + private final float score; + + /** + * Constructs a {@link Category} object. + * + * @param displayName the display name of the label, which may be translated for different + * locales. For exmaple, a label, "apple", may be translated into Spanish for display purpose, + * so that the displayName is "manzana". + */ + @UsedByReflection("TFLiteSupport/Task") + public static Category create(String label, String displayName, float score) { + return new Category(label, displayName, score); + } + + @UsedByReflection("TFLiteSupport/Task") + /** Constructs a {@link Category} object with an empty displayName. */ + public Category(String label, float score) { + this(label, /*displayName=*/ "", score); + } + + private Category(String label, String displayName, float score) { + this.label = label; + this.displayName = displayName; + this.score = score; + } + + /** Gets the reference of category's label. */ + public String getLabel() { + return label; + } + + /** + * Gets the reference of category's displayName, a name in locale of the label. + * + * <p>The display name can be an empty string if this {@link Category} object is constructed + * without displayName, such as when using {@link #Category(String label, float score)}. + */ + public String getDisplayName() { + return displayName; + } + + /** Gets the score of the category. 
*/ + public float getScore() { + return score; + } + + @Override + public boolean equals(Object o) { + if (o instanceof Category) { + Category other = (Category) o; + return (other.getLabel().equals(this.label) + && other.getDisplayName().equals(this.displayName) + && other.getScore() == this.score); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(label, displayName, score); + } + + @Override + public String toString() { + return "<Category \"" + label + "\" (displayName=" + displayName + "\" (score=" + score + ")>"; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java new file mode 100644 index 00000000..840ed5fb --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/LabelUtil.java @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.label; + +import android.util.Log; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** Label operation utils. 
*/ +public class LabelUtil { + /** + * Maps an int value tensor to a list of string labels. It takes an array of strings as the + * dictionary. Example: if the given tensor is [3, 1, 0], and given labels is ["background", + * "apple", "banana", "cherry", "date"], the result will be ["date", "banana", "apple"]. + * + * @param tensorBuffer: A tensor with index values. The values should be non-negative integers, + * and each value {@code x} will be converted to {@code labels[x + offset]}. If the tensor is + * given as a float {@link TensorBuffer}, values will be cast to integers. All values that are + * out of bound will map to empty string. + * @param labels: A list of strings, used as a dictionary to look up. The index of the array + * element will be used as the key. To get better performance, use an object that implements + * RandomAccess, such as {@link ArrayList}. + * @param offset: The offset value when look up int values in the {@code labels}. + * @return the mapped strings. The length of the list is {@link TensorBuffer#getFlatSize}. + * @throws IllegalArgumentException if {@code tensorBuffer} or {@code labels} is null. + */ + public static List<String> mapValueToLabels( + @NonNull TensorBuffer tensorBuffer, @NonNull List<String> labels, int offset) { + SupportPreconditions.checkNotNull(tensorBuffer, "Given tensor should not be null"); + SupportPreconditions.checkNotNull(labels, "Given labels should not be null"); + int[] values = tensorBuffer.getIntArray(); + Log.d("values", Arrays.toString(values)); + List<String> result = new ArrayList<>(); + for (int v : values) { + int index = v + offset; + if (index < 0 || index >= labels.size()) { + result.add(""); + } else { + result.add(labels.get(index)); + } + } + return result; + } + + // Private constructor to prevent initialization. 
+ private LabelUtil() {} +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java new file mode 100644 index 00000000..10763a1a --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java @@ -0,0 +1,224 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.label; + +import android.content.Context; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * TensorLabel is an util wrapper for TensorBuffers with meaningful labels on an axis. + * + * <p>For example, an image classification model may have an output tensor with shape as {1, 10}, + * where 1 is the batch size and 10 is the number of categories. In fact, on the 2nd axis, we could + * label each sub-tensor with the name or description of each corresponding category. 
{@link + * TensorLabel} could help converting the plain Tensor in {@link TensorBuffer} into a map from + * predefined labels to sub-tensors. In this case, if provided 10 labels for the 2nd axis, {@link + * TensorLabel} could convert the original {1, 10} Tensor to a 10 element map, each value of which + * is Tensor in shape {} (scalar). Usage example: + * + * <pre> + * TensorBuffer outputTensor = ...; + * {@literal List<String>} labels = FileUtil.loadLabels(context, labelFilePath); + * // labels the first axis with size greater than one + * TensorLabel labeled = new TensorLabel(labels, outputTensor); + * // If each sub-tensor has effectively size 1, we can directly get a float value + * {@literal Map<String, Float>} probabilities = labeled.getMapWithFloatValue(); + * // Or get sub-tensors, when each sub-tensor has elements more than 1 + * {@literal Map<String, TensorBuffer>} subTensors = labeled.getMapWithTensorBuffer(); + * </pre> + * + * <p>Note: currently we only support tensor-to-map conversion for the first label with size greater + * than 1. + * + * @see org.tensorflow.lite.support.common.FileUtil#loadLabels(Context, String) to load labels from + * a label file (plain text file whose each line is a label) in assets simply. + */ +public class TensorLabel { + private final Map<Integer, List<String>> axisLabels; + private final TensorBuffer tensorBuffer; + private final int[] shape; + + /** + * Creates a TensorLabel object which is able to label on the axes of multi-dimensional tensors. + * + * @param axisLabels A map, whose key is axis id (starting from 0) and value is corresponding + * labels. Note: The size of labels should be same with the size of the tensor on that axis. + * @param tensorBuffer The TensorBuffer to be labeled. + * @throws NullPointerException if {@code axisLabels} or {@code tensorBuffer} is null, or any + * value in {@code axisLabels} is null. 
+ * @throws IllegalArgumentException if any key in {@code axisLabels} is out of range (compared to + * the shape of {@code tensorBuffer}, or any value (labels) has different size with the {@code + * tensorBuffer} on the given dimension. + */ + public TensorLabel( + @NonNull Map<Integer, List<String>> axisLabels, @NonNull TensorBuffer tensorBuffer) { + SupportPreconditions.checkNotNull(axisLabels, "Axis labels cannot be null."); + SupportPreconditions.checkNotNull(tensorBuffer, "Tensor Buffer cannot be null."); + this.axisLabels = axisLabels; + this.tensorBuffer = tensorBuffer; + this.shape = tensorBuffer.getShape(); + for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) { + int axis = entry.getKey(); + SupportPreconditions.checkArgument( + axis >= 0 && axis < shape.length, "Invalid axis id: " + axis); + SupportPreconditions.checkNotNull(entry.getValue(), "Label list is null on axis " + axis); + SupportPreconditions.checkArgument( + shape[axis] == entry.getValue().size(), + "Label number " + entry.getValue().size() + " mismatch the shape on axis " + axis); + } + } + + /** + * Creates a TensorLabel object which is able to label on one axis of multi-dimensional tensors. + * + * <p>Note: The labels are applied on the first axis whose size is larger than 1. For example, if + * the shape of the tensor is [1, 10, 3], the labels will be applied on axis 1 (id starting from + * 0), and size of {@code axisLabels} should be 10 as well. + * + * @param axisLabels A list of labels, whose size should be same with the size of the tensor on + * the to-be-labeled axis. + * @param tensorBuffer The TensorBuffer to be labeled. + */ + public TensorLabel(@NonNull List<String> axisLabels, @NonNull TensorBuffer tensorBuffer) { + this(makeMap(getFirstAxisWithSizeGreaterThanOne(tensorBuffer), axisLabels), tensorBuffer); + } + + /** + * Gets the map with a pair of the label and the corresponding TensorBuffer. 
Only allow the + * mapping on the first axis with size greater than 1 currently. + */ + @NonNull + public Map<String, TensorBuffer> getMapWithTensorBuffer() { + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); + + Map<String, TensorBuffer> labelToTensorMap = new LinkedHashMap<>(); + SupportPreconditions.checkArgument( + axisLabels.containsKey(labeledAxis), + "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis."); + List<String> labels = axisLabels.get(labeledAxis); + + DataType dataType = tensorBuffer.getDataType(); + int typeSize = tensorBuffer.getTypeSize(); + int flatSize = tensorBuffer.getFlatSize(); + + // Gets the underlying bytes that could be used to generate the sub-array later. + ByteBuffer byteBuffer = tensorBuffer.getBuffer(); + byteBuffer.rewind(); + + // Note: computation below is only correct when labeledAxis is the first axis with size greater + // than 1. + int subArrayLength = flatSize / shape[labeledAxis] * typeSize; + int i = 0; + SupportPreconditions.checkNotNull(labels, "Label list should never be null"); + for (String label : labels) { + // Gets the corresponding TensorBuffer. + byteBuffer.position(i * subArrayLength); + ByteBuffer subBuffer = byteBuffer.slice(); + // ByteBuffer.slice doesn't keep order. Modify it to align with the original one. + subBuffer.order(byteBuffer.order()).limit(subArrayLength); + TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType); + labelBuffer.loadBuffer(subBuffer, Arrays.copyOfRange(shape, labeledAxis + 1, shape.length)); + labelToTensorMap.put(label, labelBuffer); + i += 1; + } + return labelToTensorMap; + } + + /** + * Gets a map that maps label to float. Only allow the mapping on the first axis with size greater + * than 1, and the axis should be effectively the last axis (which means every sub tensor + * specified by this axis should have a flat size of 1). 
+ * + * <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result. + * + * @throws IllegalStateException if size of a sub tensor on each label is not 1. + */ + @NonNull + public Map<String, Float> getMapWithFloatValue() { + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); + SupportPreconditions.checkState( + labeledAxis == shape.length - 1, + "get a <String, Scalar> map is only valid when the only labeled axis is the last one."); + List<String> labels = axisLabels.get(labeledAxis); + float[] data = tensorBuffer.getFloatArray(); + SupportPreconditions.checkState(labels.size() == data.length); + Map<String, Float> result = new LinkedHashMap<>(); + int i = 0; + for (String label : labels) { + result.put(label, data[i]); + i += 1; + } + return result; + } + + /** + * Gets a list of {@link Category} from the {@link TensorLabel} object. + * + * <p>The axis of label should be effectively the last axis (which means every sub tensor + * specified by this axis should have a flat size of 1), so that each labelled sub tensor could be + * converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}} + * and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}. + * + * <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as + * the result. + * + * @throws IllegalStateException if size of a sub tensor on each label is not 1. 
+ */ + @NonNull + public List<Category> getCategoryList() { + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); + SupportPreconditions.checkState( + labeledAxis == shape.length - 1, + "get a Category list is only valid when the only labeled axis is the last one."); + List<String> labels = axisLabels.get(labeledAxis); + float[] data = tensorBuffer.getFloatArray(); + SupportPreconditions.checkState(labels.size() == data.length); + List<Category> result = new ArrayList<>(); + int i = 0; + for (String label : labels) { + result.add(new Category(label, data[i])); + i += 1; + } + return result; + } + + private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) { + int[] shape = tensorBuffer.getShape(); + for (int i = 0; i < shape.length; i++) { + if (shape[i] > 1) { + return i; + } + } + throw new IllegalArgumentException( + "Cannot find an axis to label. A valid axis to label should have size larger than 1."); + } + + // Helper function to wrap the List<String> to a one-entry map. + private static Map<Integer, List<String>> makeMap(int axis, List<String> labels) { + Map<Integer, List<String>> map = new LinkedHashMap<>(); + map.put(axis, labels); + return map; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java new file mode 100644 index 00000000..c2de8c0b --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java @@ -0,0 +1,74 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.label.ops; + +import android.content.Context; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.common.SupportPreconditions; +import org.tensorflow.lite.support.label.TensorLabel; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Labels TensorBuffer with axisLabels for outputs. + * + * <p>Apply on a {@code TensorBuffer} to get a {@code TensorLabel} that could output a Map, which is + * a pair of the label name and the corresponding TensorBuffer value. + */ +public class LabelAxisOp { + // Axis and its corresponding label names. + private final Map<Integer, List<String>> axisLabels; + + protected LabelAxisOp(Builder builder) { + axisLabels = builder.axisLabels; + } + + public TensorLabel apply(@NonNull TensorBuffer buffer) { + SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null."); + return new TensorLabel(axisLabels, buffer); + } + + /** The inner builder class to build a LabelTensor Operator. 
*/ + public static class Builder { + private final Map<Integer, List<String>> axisLabels; + + protected Builder() { + axisLabels = new HashMap<>(); + } + + public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath) + throws IOException { + SupportPreconditions.checkNotNull(context, "Context cannot be null."); + SupportPreconditions.checkNotNull(filePath, "File path cannot be null."); + List<String> labels = FileUtil.loadLabels(context, filePath); + axisLabels.put(axis, labels); + return this; + } + + public Builder addAxisLabel(int axis, @NonNull List<String> labels) { + axisLabels.put(axis, labels); + return this; + } + + public LabelAxisOp build() { + return new LabelAxisOp(this); + } + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java new file mode 100644 index 00000000..9cfcf923 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.model; + +import android.util.Log; +import java.io.Closeable; +import java.io.IOException; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.tensorflow.lite.Delegate; + +/** + * Helper class to create and call necessary methods of {@code GpuDelegate} which is not a strict + * dependency. + */ +class GpuDelegateProxy implements Delegate, Closeable { + + private static final String TAG = "GpuDelegateProxy"; + + private final Delegate proxiedDelegate; + private final Closeable proxiedCloseable; + + @Nullable + public static GpuDelegateProxy maybeNewInstance() { + try { + Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate"); + Object instance = clazz.getDeclaredConstructor().newInstance(); + return new GpuDelegateProxy(instance); + } catch (ReflectiveOperationException e) { + Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e); + return null; + } + } + + /** Calls {@code close()} method of the delegate. */ + @Override + public void close() { + try { + proxiedCloseable.close(); + } catch (IOException e) { + // Should not trigger, because GpuDelegate#close never throws. The catch is required because + // of Closeable#close. + Log.e(TAG, "Failed to close the GpuDelegate.", e); + } + } + + /** Calls {@code getNativeHandle()} method of the delegate. 
*/ + @Override + public long getNativeHandle() { + return proxiedDelegate.getNativeHandle(); + } + + private GpuDelegateProxy(Object instance) { + this.proxiedCloseable = (Closeable) instance; + this.proxiedDelegate = (Delegate) instance; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java new file mode 100644 index 00000000..8062d68d --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/model/Model.java @@ -0,0 +1,285 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.model; + +import android.content.Context; +import java.io.IOException; +import java.nio.MappedByteBuffer; +import java.util.Map; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.tensorflow.lite.Interpreter; +import org.tensorflow.lite.Tensor; +import org.tensorflow.lite.support.common.FileUtil; +import org.tensorflow.lite.support.common.SupportPreconditions; + +/** + * The wrapper class for a TFLite model and a TFLite interpreter. + * + * <p>Note: A {@link Model} can only holds 1 TFLite model at a time, and always holds a TFLite + * interpreter instance to run it. 
+ */ +public class Model { + + /** The runtime device type used for executing classification. */ + public enum Device { + CPU, + NNAPI, + GPU + } + + /** + * Options for running the model. Configurable parameters includes: + * + * <ul> + * <li>{@code device} {@link Builder#setDevice(Device)} specifies the hardware to run the model. + * The default value is {@link Device#CPU}. + * <li>{@code numThreads} {@link Builder#setNumThreads(int)} specifies the number of threads + * used by TFLite inference. It's only effective when device is set to {@link Device#CPU} + * and default value is 1. + * </ul> + */ + public static class Options { + private final Device device; + private final int numThreads; + + /** Builder of {@link Options}. See its doc for details. */ + public static class Builder { + private Device device = Device.CPU; + private int numThreads = 1; + + public Builder setDevice(Device device) { + this.device = device; + return this; + } + + public Builder setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + + public Options build() { + return new Options(this); + } + } + + private Options(Builder builder) { + device = builder.device; + numThreads = builder.numThreads; + } + } + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + private final Interpreter interpreter; + + /** Path to tflite model file in asset folder. */ + private final String modelPath; + + /** The memory-mapped model data. */ + private final MappedByteBuffer byteModel; + + private final GpuDelegateProxy gpuDelegateProxy; + + /** + * Builder for {@link Model}. + * + * @deprecated Please use {@link Model#createModel(Context, String, Options)}. 
+ */ + @Deprecated + public static class Builder { + private Device device = Device.CPU; + private int numThreads = 1; + private final String modelPath; + private final MappedByteBuffer byteModel; + + /** + * Creates a builder which loads tflite model from asset folder using memory-mapped files. + * + * @param context: Application context to access assets. + * @param modelPath: Asset path of the model (.tflite file). + * @throws IOException if an I/O error occurs when loading the tflite model. + */ + @NonNull + public Builder(@NonNull Context context, @NonNull String modelPath) throws IOException { + this.modelPath = modelPath; + byteModel = FileUtil.loadMappedFile(context, modelPath); + } + + /** Sets running device. By default, TFLite will run on CPU. */ + @NonNull + public Builder setDevice(Device device) { + this.device = device; + return this; + } + + /** Sets number of threads. By default it's 1. */ + @NonNull + public Builder setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + + // Note: The implementation is copied from `Model#createModel`. As the builder is going to be + // deprecated, this function is also to be removed. + @NonNull + public Model build() { + Options options = new Options.Builder().setNumThreads(numThreads).setDevice(device).build(); + return createModel(byteModel, modelPath, options); + } + } + + /** + * Loads a model from assets and initialize TFLite interpreter. + * + * <p>The default options are: (1) CPU device; (2) one thread. + * + * @param context The App Context. + * @param modelPath The path of the model file. + * @throws IOException if any exception occurs when open the model file. + */ + public static Model createModel(@NonNull Context context, @NonNull String modelPath) + throws IOException { + return createModel(context, modelPath, new Options.Builder().build()); + } + + /** + * Loads a model from assets and initialize TFLite interpreter with given options. + * + * @see Options for details. 
+ * @param context The App Context. + * @param modelPath The path of the model file. + * @param options The options for running the model. + * @throws IOException if any exception occurs when open the model file. + */ + public static Model createModel( + @NonNull Context context, @NonNull String modelPath, @NonNull Options options) + throws IOException { + SupportPreconditions.checkNotEmpty( + modelPath, "Model path in the asset folder cannot be empty."); + MappedByteBuffer byteModel = FileUtil.loadMappedFile(context, modelPath); + return createModel(byteModel, modelPath, options); + } + + /** + * Creates a model with loaded {@link MappedByteBuffer}. + * + * @see Options for details. + * @param byteModel The loaded TFLite model. + * @param modelPath The original path of the model. It can be fetched later by {@link + * Model#getPath()}. + * @param options The options for running the model. + * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but + * "tensorflow-lite-gpu" is not linked to the project. + */ + public static Model createModel( + @NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) { + Interpreter.Options interpreterOptions = new Interpreter.Options(); + GpuDelegateProxy gpuDelegateProxy = null; + switch (options.device) { + case NNAPI: + interpreterOptions.setUseNNAPI(true); + break; + case GPU: + gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance(); + SupportPreconditions.checkArgument( + gpuDelegateProxy != null, + "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?"); + interpreterOptions.addDelegate(gpuDelegateProxy); + break; + case CPU: + break; + } + interpreterOptions.setNumThreads(options.numThreads); + Interpreter interpreter = new Interpreter(byteModel, interpreterOptions); + return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy); + } + + /** Returns the memory-mapped model data. 
*/ + @NonNull + public MappedByteBuffer getData() { + return byteModel; + } + + /** Returns the path of the model file stored in Assets. */ + @NonNull + public String getPath() { + return modelPath; + } + + /** + * Gets the Tensor associated with the provdied input index. + * + * @throws IllegalStateException if the interpreter is closed. + */ + public Tensor getInputTensor(int inputIndex) { + return interpreter.getInputTensor(inputIndex); + } + + /** + * Gets the Tensor associated with the provdied output index. + * + * @throws IllegalStateException if the interpreter is closed. + */ + public Tensor getOutputTensor(int outputIndex) { + return interpreter.getOutputTensor(outputIndex); + } + + /** + * Returns the output shape. Useful if output shape is only determined when graph is created. + * + * @throws IllegalStateException if the interpreter is closed. + */ + public int[] getOutputTensorShape(int outputIndex) { + return interpreter.getOutputTensor(outputIndex).shape(); + } + + /** + * Runs model inference on multiple inputs, and returns multiple outputs. + * + * @param inputs an array of input data. The inputs should be in the same order as inputs of the + * model. Each input can be an array or multidimensional array, or a {@link + * java.nio.ByteBuffer} of primitive types including int, float, long, and byte. {@link + * java.nio.ByteBuffer} is the preferred way to pass large input data, whereas string types + * require using the (multi-dimensional) array input path. When {@link java.nio.ByteBuffer} is + * used, its content should remain unchanged until model inference is done. + * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link + * java.nio.ByteBuffer}s of primitive types including int, float, long, and byte. It only + * needs to keep entries for the outputs to be used. 
+ */ + public void run(@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) { + interpreter.runForMultipleInputsOutputs(inputs, outputs); + } + + public void close() { + if (interpreter != null) { + interpreter.close(); + } + if (gpuDelegateProxy != null) { + gpuDelegateProxy.close(); + } + } + + private Model( + @NonNull String modelPath, + @NonNull MappedByteBuffer byteModel, + @NonNull Interpreter interpreter, + @Nullable GpuDelegateProxy gpuDelegateProxy) { + this.modelPath = modelPath; + this.byteModel = byteModel; + this.interpreter = interpreter; + this.gpuDelegateProxy = gpuDelegateProxy; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java new file mode 100644 index 00000000..446d0ea5 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBuffer.java @@ -0,0 +1,430 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.support.tensorbuffer; + +import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.common.SupportPreconditions.checkNotNull; +import static org.tensorflow.lite.support.common.SupportPreconditions.checkState; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.DataType; + +/** Represents the data buffer for either a model's input or its output. */ +public abstract class TensorBuffer { + /** Where the data is stored. */ + protected ByteBuffer buffer; + + /** Shape of the tensor stored in this buffer. */ + protected int[] shape; + + /** Number of elements in the buffer. It will be changed to a proper value in the constructor. */ + protected int flatSize = -1; + + /** + * Indicator of whether this buffer is dynamic or fixed-size. Fixed-size buffers will have + * pre-allocated memory and fixed size. While the size of dynamic buffers can be changed. + */ + protected final boolean isDynamic; + + /** + * Creates a {@link TensorBuffer} with specified {@code shape} and {@link DataType}. Here are some + * examples: + * + * <pre> + * Creating a float TensorBuffer with shape {2, 3}: + * int[] shape = new int[] {2, 3}; + * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.FLOAT32); + * </pre> + * + * <pre> + * Creating an uint8 TensorBuffer of a scalar: + * int[] shape = new int[] {}; + * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); + * </pre> + * + * <pre> + * Creating an empty uint8 TensorBuffer: + * int[] shape = new int[] {0}; + * TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(shape, DataType.UINT8); + * </pre> + * + * <p>The size of a fixed-size TensorBuffer cannot be changed once it is created. 
+ * + * @param shape The shape of the {@link TensorBuffer} to be created. + * @param dataType The dataType of the {@link TensorBuffer} to be created. + * @throws NullPointerException if {@code shape} is null. + * @throws IllegalArgumentException if {@code shape} has non-positive elements. + */ + @NonNull + public static TensorBuffer createFixedSize(@NonNull int[] shape, DataType dataType) { + switch (dataType) { + case FLOAT32: + return new TensorBufferFloat(shape); + case UINT8: + return new TensorBufferUint8(shape); + default: + throw new AssertionError("TensorBuffer does not support data type: " + dataType); + } + } + + /** + * Creates an empty dynamic {@link TensorBuffer} with specified {@link DataType}. The shape of the + * created {@link TensorBuffer} is {0}. + * + * <p>Dynamic TensorBuffers will reallocate memory when loading arrays or data buffers of + * different buffer sizes. + * + * @param dataType The dataType of the {@link TensorBuffer} to be created. + */ + @NonNull + public static TensorBuffer createDynamic(DataType dataType) { + switch (dataType) { + case FLOAT32: + return new TensorBufferFloat(); + case UINT8: + return new TensorBufferUint8(); + default: + throw new AssertionError("TensorBuffer does not support data type: " + dataType); + } + } + + /** + * Creates a {@link TensorBuffer} deep-copying data from another, with specified {@link DataType}. + * + * @param buffer the source {@link TensorBuffer} to copy from. + * @param dataType the expected {@link DataType} of newly created {@link TensorBuffer}. + * @throws NullPointerException if {@code buffer} is null. 
+ */ + @NonNull + public static TensorBuffer createFrom(@NonNull TensorBuffer buffer, DataType dataType) { + checkNotNull(buffer, "Cannot create a buffer from null"); + TensorBuffer result; + if (buffer.isDynamic()) { + result = createDynamic(dataType); + } else { + result = createFixedSize(buffer.shape, dataType); + } + // The only scenario we need float array is FLOAT32->FLOAT32, or we can always use INT as + // intermediate container. + // The assumption is not true when we support other data types. + if (buffer.getDataType() == DataType.FLOAT32 && dataType == DataType.FLOAT32) { + float[] data = buffer.getFloatArray(); + result.loadArray(data, buffer.shape); + } else { + int[] data = buffer.getIntArray(); + result.loadArray(data, buffer.shape); + } + return result; + } + + /** Returns the data buffer. */ + @NonNull + public ByteBuffer getBuffer() { + return buffer; + } + + /** + * Gets the {@link TensorBuffer#flatSize} of the buffer. + * + * @throws IllegalStateException if the underlying data is corrupted + */ + public int getFlatSize() { + assertShapeIsCorect(); + return flatSize; + } + + /** + * Gets the current shape. (returning a copy here to avoid unexpected modification.) + * + * @throws IllegalStateException if the underlying data is corrupted + */ + @NonNull + public int[] getShape() { + assertShapeIsCorect(); + return Arrays.copyOf(shape, shape.length); + } + + /** Returns the data type of this buffer. */ + public abstract DataType getDataType(); + + /** + * Returns a float array of the values stored in this buffer. If the buffer is of different types + * than float, the values will be converted into float. For example, values in {@link + * TensorBufferUint8} will be converted from uint8 to float. + */ + @NonNull + public abstract float[] getFloatArray(); + + /** + * Returns a float value at a given index. If the buffer is of different types than float, the + * value will be converted into float. 
For example, when reading a value from {@link + * TensorBufferUint8}, the value will be first read out as uint8, and then will be converted from + * uint8 to float. + * + * <pre> + * For example, a TensorBuffer with shape {2, 3} that represents the following array, + * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. + * + * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by: + * float v = tensorBuffer.getFloatValue(3); + * </pre> + * + * @param absIndex The absolute index of the value to be read. + */ + public abstract float getFloatValue(int absIndex); + + /** + * Returns an int array of the values stored in this buffer. If the buffer is of different type + * than int, the values will be converted into int, and loss of precision may apply. For example, + * getting an int array from a {@link TensorBufferFloat} with values {400.32f, 23.04f}, the output + * is {400, 23}. + */ + @NonNull + public abstract int[] getIntArray(); + + /** + * Returns an int value at a given index. If the buffer is of different types than int, the value + * will be converted into int. For example, when reading a value from {@link TensorBufferFloat}, + * the value will be first read out as float, and then will be converted from float to int. Loss + * of precision may apply. + * + * <pre> + * For example, a TensorBuffer with shape {2, 3} that represents the following array, + * [[0.0f, 1.0f, 2.0f], [3.0f, 4.0f, 5.0f]]. + * + * The fourth element (whose value is 3.0f) in the TensorBuffer can be retrived by: + * int v = tensorBuffer.getIntValue(3); + * Note that v is converted from 3.0f to 3 as a result of type conversion. + * </pre> + * + * @param absIndex The absolute index of the value to be read. + */ + public abstract int getIntValue(int absIndex); + + /** + * Returns the number of bytes of a single element in the array. For example, a float buffer will + * return 4, and a byte buffer will return 1. 
+ */ + public abstract int getTypeSize(); + + /** Returns if the {@link TensorBuffer} is dynamic sized (could resize arbitrarily). */ + public boolean isDynamic() { + return isDynamic; + } + + /** + * Loads an int array into this buffer with specific shape. If the buffer is of different types + * than int, the values will be converted into the buffer's type before being loaded into the + * buffer, and loss of precision may apply. For example, loading an int array with values {400, + * -23} into a {@link TensorBufferUint8} , the values will be clamped to [0, 255] and then be + * casted to uint8 by {255, 0}. + * + * @param src The source array to be loaded. + * @param shape Shape of the tensor that {@code src} represents. + * @throws NullPointerException if {@code src} is null. + * @throws NullPointerException if {@code shape} is null. + * @throws IllegalArgumentException if the size of the array to be loaded does not match the + * specified shape. + */ + public abstract void loadArray(@NonNull int[] src, @NonNull int[] shape); + + /** + * Loads an int array into this buffer. If the buffer is of different types than int, the values + * will be converted into the buffer's type before being loaded into the buffer, and loss of + * precision may apply. For example, loading an int array with values {400, -23} into a {@link + * TensorBufferUint8} , the values will be clamped to [0, 255] and then be casted to uint8 by + * {255, 0}. + * + * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both + * fixed-size and dynamic {@link TensorBuffer}. + * + * @param src The source array to be loaded. + */ + public void loadArray(@NonNull int[] src) { + loadArray(src, shape); + } + + /** + * Loads a float array into this buffer with specific shape. If the buffer is of different types + * than float, the values will be converted into the buffer's type before being loaded into the + * buffer, and loss of precision may apply. 
For example, loading a float array into a {@link + * TensorBufferUint8} with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and + * then be casted to uint8 by {255, 0}. + * + * @param src The source array to be loaded. + * @param shape Shape of the tensor that {@code src} represents. + * @throws NullPointerException if {@code src} is null. + * @throws NullPointerException if {@code shape} is null. + * @throws IllegalArgumentException if the size of the array to be loaded does not match the + * specified shape. + */ + public abstract void loadArray(@NonNull float[] src, @NonNull int[] shape); + + /** + * Loads a float array into this buffer. If the buffer is of different types than float, the + * values will be converted into the buffer's type before being loaded into the buffer, and loss + * of precision may apply. For example, loading a float array into a {@link TensorBufferUint8} + * with values {400.32f, -23.04f}, the values will be clamped to [0, 255] and then be casted to + * uint8 by {255, 0}. + * + * <p>Size of {@code src} should always match the flat size of this {@link TensorBuffer}, for both + * fixed-size and dynamic {@link TensorBuffer}. + * + * @param src The source array to be loaded. + */ + public void loadArray(@NonNull float[] src) { + loadArray(src, shape); + } + + /** + * Loads a byte buffer into this {@link TensorBuffer} with specific shape. + * + * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for + * performance concern, but if modification is necessary, please make a copy. + * + * @param buffer The byte buffer to load. + * @throws NullPointerException if {@code buffer} is null. + * @throws IllegalArgumentException if the size of {@code buffer} and {@code typeSize} do not + * match or the size of {@code buffer} and {@code flatSize} do not match. 
+ */ + public void loadBuffer(@NonNull ByteBuffer buffer, @NonNull int[] shape) { + checkNotNull(buffer, "Byte buffer cannot be null."); + int flatSize = computeFlatSize(shape); + checkArgument( + (buffer.limit() == getTypeSize() * flatSize), + "The size of byte buffer and the shape do not match."); + + resize(shape); + buffer.rewind(); + this.buffer = buffer; + } + + /** + * Loads a byte buffer into this {@link TensorBuffer}. Buffer size must match the flat size of + * this {@link TensorBuffer}. + * + * <p>Important: The loaded buffer is a reference. DO NOT MODIFY. We don't create a copy here for + * performance concern, but if modification is necessary, please make a copy. + * + * @param buffer The byte buffer to load. + */ + public void loadBuffer(@NonNull ByteBuffer buffer) { + loadBuffer(buffer, shape); + } + + /** + * Constructs a fixed size {@link TensorBuffer} with specified {@code shape}. + * + * @throws NullPointerException if {@code shape} is null. + * @throws IllegalArgumentException if {@code shape} has non-positive elements. + */ + protected TensorBuffer(@NonNull int[] shape) { + isDynamic = false; + allocateMemory(shape); + } + + /** Constructs a dynamic {@link TensorBuffer} which can be resized. */ + protected TensorBuffer() { + isDynamic = true; + // Initialize the dynamic TensorBuffer with an empty ByteBuffer. + allocateMemory(new int[] {0}); + } + + /** Calculates number of elements in the buffer. */ + protected static int computeFlatSize(@NonNull int[] shape) { + checkNotNull(shape, "Shape cannot be null."); + int prod = 1; + for (int s : shape) { + prod = prod * s; + } + return prod; + } + + /** + * For dynamic buffer, resize the memory if needed. For fixed-size buffer, check if the {@code + * shape} of src fits the buffer size. + */ + protected void resize(@NonNull int[] shape) { + if (isDynamic) { + allocateMemory(shape); + } else { + // Make sure the new shape fits the buffer size when TensorBuffer has fixed size. 
+ checkArgument(Arrays.equals(shape, this.shape)); + this.shape = shape.clone(); + } + } + + /** + * Allocates buffer with corresponding size of the {@code shape}. If shape is an empty array, this + * {@link TensorBuffer} will be created as a scalar and its flatSize will be 1. + * + * @throws NullPointerException if {@code shape} is null. + * @throws IllegalArgumentException if {@code shape} has negative elements. + */ + private void allocateMemory(@NonNull int[] shape) { + checkNotNull(shape, "TensorBuffer shape cannot be null."); + checkArgument(isShapeValid(shape), "Values in TensorBuffer shape should be non-negative."); + + // Check if the new shape is the same as current shape. + int newFlatSize = computeFlatSize(shape); + this.shape = shape.clone(); + if (flatSize == newFlatSize) { + return; + } + + // Update to the new shape. + flatSize = newFlatSize; + buffer = ByteBuffer.allocateDirect(flatSize * getTypeSize()); + buffer.order(ByteOrder.nativeOrder()); + } + + /** + * Verifies if the shape of the {@link TensorBuffer} matched the size of the underlying {@link + * ByteBuffer}. + */ + private void assertShapeIsCorect() { + int flatSize = computeFlatSize(shape); + checkState( + (buffer.limit() == getTypeSize() * flatSize), + String.format( + "The size of underlying ByteBuffer (%d) and the shape (%s) do not match. The" + + " ByteBuffer may have been changed.", + buffer.limit(), Arrays.toString(shape))); + } + + /** + * Checks if {@code shape} meets one of following two requirements: 1. Elements in {@code shape} + * are all non-negative numbers. 2. {@code shape} is an empty array, which corresponds to scalar. + */ + private static boolean isShapeValid(@NonNull int[] shape) { + if (shape.length == 0) { + // This shape refers to a scalar. + return true; + } + + // This shape refers to a multidimensional array. + for (int s : shape) { + // All elements in shape should be non-negative. 
+ if (s < 0) { + return false; + } + } + return true; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java new file mode 100644 index 00000000..65bbd7d0 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferFloat.java @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.tensorbuffer; + +import java.nio.FloatBuffer; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.common.SupportPreconditions; + +/** Represents data buffer with float values. */ +public final class TensorBufferFloat extends TensorBuffer { + private static final DataType DATA_TYPE = DataType.FLOAT32; + + /** + * Creates a {@link TensorBufferFloat} with specified {@code shape}. + * + * @throws NullPointerException if {@code shape} is null. + * @throws IllegalArgumentException if {@code shape} has non-positive elements. 
+ */ + TensorBufferFloat(@NonNull int[] shape) { + super(shape); + } + + TensorBufferFloat() { + super(); + } + + @Override + public DataType getDataType() { + return DATA_TYPE; + } + + @Override + @NonNull + public float[] getFloatArray() { + buffer.rewind(); + float[] arr = new float[flatSize]; + + FloatBuffer floatBuffer = buffer.asFloatBuffer(); + floatBuffer.get(arr); + return arr; + } + + @Override + public float getFloatValue(int absIndex) { + return buffer.getFloat(absIndex << 2); + } + + @Override + @NonNull + public int[] getIntArray() { + buffer.rewind(); + float[] floatArr = new float[flatSize]; + buffer.asFloatBuffer().get(floatArr); + + int[] intArr = new int[flatSize]; + for (int i = 0; i < flatSize; i++) { + intArr[i] = (int) floatArr[i]; + } + return intArr; + } + + @Override + public int getIntValue(int absIndex) { + return (int) buffer.getFloat(absIndex << 2); + } + + @Override + public int getTypeSize() { + return DATA_TYPE.byteSize(); + } + + @Override + public void loadArray(@NonNull float[] src, @NonNull int[] shape) { + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); + SupportPreconditions.checkArgument( + src.length == computeFlatSize(shape), + "The size of the array to be loaded does not match the specified shape."); + resize(shape); + buffer.rewind(); + + FloatBuffer floatBuffer = buffer.asFloatBuffer(); + floatBuffer.put(src); + } + + @Override + public void loadArray(@NonNull int[] src, @NonNull int[] shape) { + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); + SupportPreconditions.checkArgument( + src.length == computeFlatSize(shape), + "The size of the array to be loaded does not match the specified shape."); + resize(shape); + buffer.rewind(); + + float[] floatArray = new float[src.length]; + int cnt = 0; + for (int a : src) { + floatArray[cnt++] = (float) a; + } + buffer.asFloatBuffer().put(floatArray); + } +} diff --git 
a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java new file mode 100644 index 00000000..33641940 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/tensorbuffer/TensorBufferUint8.java @@ -0,0 +1,121 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.tensorbuffer; + +import org.checkerframework.checker.nullness.qual.NonNull; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.common.SupportPreconditions; + +/** Represents data buffer with 8-bit unsigned integer values. */ +public final class TensorBufferUint8 extends TensorBuffer { + private static final DataType DATA_TYPE = DataType.UINT8; + + /** + * Creates a {@link TensorBufferUint8} with specified {@code shape}. + * + * @throws NullPointerException if {@code shape} is null. + * @throws IllegalArgumentException if {@code shape} has non-positive elements. 
+ */ + TensorBufferUint8(@NonNull int[] shape) { + super(shape); + } + + TensorBufferUint8() { + super(); + } + + @Override + public DataType getDataType() { + return DATA_TYPE; + } + + @Override + @NonNull + public float[] getFloatArray() { + buffer.rewind(); + byte[] byteArr = new byte[flatSize]; + buffer.get(byteArr); + + float[] floatArr = new float[flatSize]; + for (int i = 0; i < flatSize; i++) { + floatArr[i] = (float) (byteArr[i] & 0xff); + } + return floatArr; + } + + @Override + public float getFloatValue(int index) { + return (float) (buffer.get(index) & 0xff); + } + + @Override + @NonNull + public int[] getIntArray() { + buffer.rewind(); + byte[] byteArr = new byte[flatSize]; + buffer.get(byteArr); + + int[] intArr = new int[flatSize]; + for (int i = 0; i < flatSize; i++) { + intArr[i] = byteArr[i] & 0xff; + } + return intArr; + } + + @Override + public int getIntValue(int index) { + return buffer.get(index) & 0xff; + } + + @Override + public int getTypeSize() { + return DATA_TYPE.byteSize(); + } + + @Override + public void loadArray(@NonNull float[] src, @NonNull int[] shape) { + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); + SupportPreconditions.checkArgument( + src.length == computeFlatSize(shape), + "The size of the array to be loaded does not match the specified shape."); + resize(shape); + buffer.rewind(); + + byte[] byteArr = new byte[src.length]; + int cnt = 0; + for (float a : src) { + byteArr[cnt++] = (byte) Math.max(Math.min(a, 255.0), 0.0); + } + buffer.put(byteArr); + } + + @Override + public void loadArray(@NonNull int[] src, @NonNull int[] shape) { + SupportPreconditions.checkNotNull(src, "The array to be loaded cannot be null."); + SupportPreconditions.checkArgument( + src.length == computeFlatSize(shape), + "The size of the array to be loaded does not match the specified shape."); + resize(shape); + buffer.rewind(); + + byte[] byteArr = new byte[src.length]; + int cnt = 0; + for (float a : src) { + 
byteArr[cnt++] = (byte) Math.max(Math.min(a, 255), 0); + } + buffer.put(byteArr); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD new file mode 100644 index 00000000..f82b8009 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BUILD @@ -0,0 +1,22 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +android_library( + name = "base-task-api", + srcs = glob(["**/*.java"]), + javacopts = JAVACOPTS, + visibility = ["//visibility:public"], + deps = [ + "@com_google_auto_value", + ], +) + +alias( + name = "base_task_api", + actual = ":base-task-api", +) diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java new file mode 100644 index 00000000..b3fe9def --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/BaseTaskApi.java @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.core; + +import android.util.Log; +import java.io.Closeable; + +/** + * Base class for Task API, provides shared logic to load/unload native libs to its C++ counterpart. + */ +public abstract class BaseTaskApi implements Closeable { + private static final String TAG = BaseTaskApi.class.getSimpleName(); + + /** + * Represents a pointer to the corresponding C++ task_api object. The nativeHandle pointer is + * initialized from subclasses and must be released by calling {@link #deinit} after it is no + * longer needed. + */ + private final long nativeHandle; + + /** Indicates whether the {@link #nativeHandle} pointer has been released yet. */ + private boolean closed; + + /** + * Constructor to initialize the JNI with a pointer from C++. + * + * @param nativeHandle a pointer referencing memory allocated in C++. + */ + protected BaseTaskApi(long nativeHandle) { + if (nativeHandle == TaskJniUtils.INVALID_POINTER) { + throw new IllegalArgumentException("Failed to load C++ pointer from JNI"); + } + this.nativeHandle = nativeHandle; + } + + public boolean isClosed() { + return closed; + } + + /** Release the memory allocated from C++ and deregister the library from the static holder. */ + @Override + public synchronized void close() { + if (closed) { + return; + } + deinit(nativeHandle); + closed = true; + } + + public long getNativeHandle() { + return nativeHandle; + } + + protected void checkNotClosed() { + if (isClosed()) { + throw new IllegalStateException("Internal error: The task lib has already been closed."); + } + } + + @Override + protected void finalize() throws Throwable { + try { + if (!closed) { + Log.w(TAG, "Closing an already closed native lib"); + close(); + } + } finally { + super.finalize(); + } + } + + /** + * Releases memory pointed by the pointer in the native layer. 
+ * + * @param nativeHandle pointer to memory allocated + */ + protected abstract void deinit(long nativeHandle); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java new file mode 100644 index 00000000..f5c52a03 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java @@ -0,0 +1,165 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.core; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.util.Log; +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; + +/** JNI utils for Task API. */ +public class TaskJniUtils { + public static final long INVALID_POINTER = 0; + private static final String TAG = TaskJniUtils.class.getSimpleName(); + /** Syntax sugar to get nativeHandle from empty param list. */ + public interface EmptyHandleProvider { + long createHandle(); + } + + /** Syntax sugar to get nativeHandle from an array of {@link ByteBuffer}s. */ + public interface MultipleBuffersHandleProvider { + long createHandle(ByteBuffer... 
buffers); + } + + /** Syntax sugar to get nativeHandle from file descriptor and options. */ + public interface FdAndOptionsHandleProvider<T> { + long createHandle( + int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset, T options); + } + + /** + * Initializes the JNI and returns C++ handle with file descriptor and options for task API. + * + * @param context the Android app context + * @param provider provider to get C++ handle, usually returned from native call + * @param libName name of C++ lib to be loaded + * @param filePath path of the file to be loaded + * @param options options to set up the task API, used by the provider + * @return C++ handle as long + * @throws IOException If model file fails to load. + */ + public static <T> long createHandleFromFdAndOptions( + Context context, + final FdAndOptionsHandleProvider<T> provider, + String libName, + String filePath, + final T options) + throws IOException { + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(filePath)) { + return createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return provider.createHandle( + /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(), + /*fileDescriptorLength=*/ assetFileDescriptor.getLength(), + /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(), + options); + } + }, + libName); + } + } + + /** + * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes + * {@link EmptyHandleProvider#createHandle()}. 
+ * + * @param provider provider to get C++ handle, usually returned from native call + * @return C++ handle as long + */ + public static long createHandleFromLibrary(EmptyHandleProvider provider, String libName) { + tryLoadLibrary(libName); + try { + return provider.createHandle(); + } catch (Exception e) { + String errorMessage = "Error getting native address of native library: " + libName; + Log.e(TAG, errorMessage, e); + throw new IllegalStateException(errorMessage, e); + } + } + + /** + * Initializes the JNI and returns C++ handle by first loading the C++ library and then invokes + * {@link MultipleBuffersHandleProvider#createHandle(ByteBuffer...)}. + * + * @param context app context + * @param provider provider to get C++ pointer, usually returned from native call + * @param libName name of C++ lib to load + * @param filePaths file paths to load + * @return C++ pointer as long + * @throws IOException If model file fails to load. + */ + public static long createHandleWithMultipleAssetFilesFromLibrary( + Context context, + final MultipleBuffersHandleProvider provider, + String libName, + String... filePaths) + throws IOException { + final MappedByteBuffer[] buffers = new MappedByteBuffer[filePaths.length]; + for (int i = 0; i < filePaths.length; i++) { + buffers[i] = loadMappedFile(context, filePaths[i]); + } + return createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return provider.createHandle(buffers); + } + }, + libName); + } + + /** + * Loads a file from the asset folder through memory mapping. + * + * @param context Application context to access assets. + * @param filePath Asset path of the file. + * @return the loaded memory mapped file. + * @throws IOException If model file fails to load. 
+ */ + public static MappedByteBuffer loadMappedFile(Context context, String filePath) + throws IOException { + try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } + } + + private TaskJniUtils() {} + + /** + * Try load a native library, if it's already loaded return directly. + * + * @param libName name of the lib + */ + static void tryLoadLibrary(String libName) { + try { + System.loadLibrary(libName); + } catch (UnsatisfiedLinkError e) { + String errorMessage = "Error loading native library: " + libName; + Log.e(TAG, errorMessage, e); + throw new UnsatisfiedLinkError(errorMessage); + } + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java new file mode 100644 index 00000000..0236f2ce --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/vision/ImageProcessingOptions.java @@ -0,0 +1,117 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.core.vision; + +import android.graphics.Rect; +import com.google.auto.value.AutoValue; + +/** + * Options to configure the image processing pipeline, which operates before inference. + * + * <p>The Task Library Vision API performs image preprocessing on the input image over the region of + * interest, so that it fits model requirements (e.g. upright 224x224 RGB) and populate the + * corresponding input tensor. This is performed by (in this order): + * + * <ul> + * <li>cropping the frame buffer to the region of interest (which, in most cases, just covers the + * entire input image), + * <li>resizing it (with bilinear interpolation, aspect-ratio *not* preserved) to the dimensions + * of the model input tensor, + * <li>converting it to the colorspace of the input tensor (i.e. RGB, which is the only supported + * colorspace for now), + * <li>rotating it according to its {@link Orientation} so that inference is performed on an + * "upright" image. + * </ul> + * + * <p>IMPORTANT: as a consequence of cropping occurring first, the provided region of interest is + * expressed in the unrotated frame of reference coordinates system, i.e. in {@code [0, + * TensorImage.getWidth()) x [0, TensorImage.getHeight())}, which are the dimensions of the + * underlying image data before any orientation gets applied. If the region is out of these bounds, + * the inference method, such as {@link ImageClassifier#classify}, will return error. + */ +@AutoValue +public abstract class ImageProcessingOptions { + + /** + * Orientation type that follows EXIF specification. + * + * <p>The name of each enum value defines the position of the 0th row and the 0th column of the + * image content. See the <a href="http://jpegclub.org/exif_orientation.html">EXIF orientation + * documentation</a> for details. 
+ */ + public enum Orientation { + TOP_LEFT(0), + TOP_RIGHT(1), + BOTTOM_RIGHT(2), + BOTTOM_LEFT(3), + LEFT_TOP(4), + RIGHT_TOP(5), + RIGHT_BOTTOM(6), + LEFT_BOTTOM(7); + + private final int value; + + Orientation(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + }; + + private static final Rect defaultRoi = new Rect(); + private static final Orientation DEFAULT_ORIENTATION = Orientation.TOP_LEFT; + + public abstract Rect getRoi(); + + public abstract Orientation getOrientation(); + + public static Builder builder() { + return new AutoValue_ImageProcessingOptions.Builder() + .setRoi(defaultRoi) + .setOrientation(DEFAULT_ORIENTATION); + } + + /** Builder for {@link ImageProcessingOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + + /** + * Sets the region of interest (ROI) of the image. Defaults to the entire image. + * + * <p>Cropping according to this region of interest is prepended to the pre-processing + * operations. + */ + public abstract Builder setRoi(Rect roi); + + /** + * Sets the orientation of the image. Defaults to {@link Orientation#TOP_LEFT}. + * + * <p>Rotation will be applied accordingly so that inference is performed on an "upright" image. + */ + public abstract Builder setOrientation(Orientation orientation); + + abstract Rect getRoi(); + + abstract ImageProcessingOptions autoBuild(); + + public ImageProcessingOptions build() { + setRoi(new Rect(getRoi())); // Make a defensive copy, since Rect is mutable. 
+ return autoBuild(); + } + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml new file mode 100644 index 00000000..d4d1dbad --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/AndroidManifest.xml @@ -0,0 +1,5 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.task.text"> + <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> +</manifest> diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD new file mode 100644 index 00000000..695e1bef --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/BUILD @@ -0,0 +1,37 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["AndroidManifest.xml"]) + +android_library( + name = "task_library_text", + srcs = [ + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier:nl_classifier_src", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa:bert_question_answerer_src", + ], + javacopts = JAVACOPTS, + manifest = "AndroidManifest.xml", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "//tensorflow_lite_support/java/src/native/task/text:task_text_native", + "@com_google_auto_value", + 
"@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + ], +) + +# AAR target for OSS release. +# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text:task-library-text +aar_with_jni( + name = "task-library-text", + android_library = ":task_library_text", +) diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD new file mode 100644 index 00000000..a1d78d8f --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BUILD @@ -0,0 +1,79 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "nl_classifier_src", + srcs = glob(["**/*.java"]), +) + +# Java-only target, need to be used together with a native target similar to +# third_party/tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_native. +# Use this target when you want to provide a MutableOpResolver with customized +# OPs and/or a subset of BuiltInOps to reduce binary size. +android_library( + name = "nl_classifier_java", + srcs = [ + "NLClassifier.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "@com_google_auto_value", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + ], +) + +# Default target that uses BuiltInOpResolver, registers all built-in OPs. 
+android_library( + name = "nl_classifier", + srcs = [ + "NLClassifier.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_native", + "@com_google_auto_value", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + ], +) + +# AAR target for OSS release. +# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier:nl-classifier +aar_with_jni( + name = "nl-classifier", + android_library = ":nl_classifier", +) + +# Default target that uses BuiltInOpResolver, registers all built-in OPs. +android_library( + name = "bert_nl_classifier", + srcs = [ + "BertNLClassifier.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier:bert_nl_classifier_native", + ], +) + +# AAR target for OSS release. 
+# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier:bert-nl-classifier +aar_with_jni( + name = "bert-nl-classifier", + android_library = ":bert_nl_classifier", +) diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java new file mode 100644 index 00000000..90bea370 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java @@ -0,0 +1,142 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.text.nlclassifier; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.util.List; +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + +/** + * Classifier API for NLClassification tasks with Bert models, categorizes string into different + * classes. The API expects a Bert based TFLite model with metadata populated. + * + * <p>The metadata should contain the following information: + * + * <ul> + * <li>1 input_process_unit for Wordpiece/Sentencepiece Tokenizer. + * <li>3 input tensors with names "ids", "mask" and "segment_ids". + * <li>1 output tensor of type float32[1, 2], with a optionally attached label file. If a label + * file is attached, the file should be a plain text file with one label per line, the number + * of labels should match the number of categories the model outputs. + * </ul> + */ +public class BertNLClassifier extends BaseTaskApi { + private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; + + /** + * Constructor to initialize the JNI with a pointer from C++. + * + * @param nativeHandle a pointer referencing memory allocated in C++. + */ + private BertNLClassifier(long nativeHandle) { + super(nativeHandle); + } + + /** + * Create {@link BertNLClassifier} from a model file with metadata. + * + * @param context Android context + * @param pathToModel Path to the classification model. + * @return {@link BertNLClassifier} instance. + * @throws IOException If model file fails to load. 
+ */ + public static BertNLClassifier createFromFile(final Context context, final String pathToModel) + throws IOException { + return createFromBuffer(TaskJniUtils.loadMappedFile(context, pathToModel)); + } + + /** + * Create {@link BertNLClassifier} from a {@link File} object with metadata. + * + * @param modelFile The classification model {@link File} instance. + * @return {@link BertNLClassifier} instance. + * @throws IOException If model file fails to load. + */ + public static BertNLClassifier createFromFile(File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new BertNLClassifier( + TaskJniUtils.createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithFileDescriptor(descriptor.getFd()); + } + }, + BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); + } + } + + /** + * Create {@link BertNLClassifier} with a model buffer. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the model + * @return {@link BertNLClassifier} instance + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static BertNLClassifier createFromBuffer(final ByteBuffer modelBuffer) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new BertNLClassifier( + TaskJniUtils.createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer); + } + }, + BERT_NL_CLASSIFIER_NATIVE_LIBNAME)); + } + + /** + * Perform classification on a string input, returns classified {@link Category}s. + * + * @param text input text to the model. + * @return A list of Category results. 
+ */ + public List<Category> classify(String text) { + return classifyNative(getNativeHandle(), text); + } + + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer); + + private static native long initJniWithFileDescriptor(int fd); + + private static native List<Category> classifyNative(long nativeHandle, String text); + + @Override + protected void deinit(long nativeHandle) { + deinitJni(nativeHandle); + } + + /** + * Native implementation to release memory pointed by the pointer. + * + * @param nativeHandle pointer to memory allocated + */ + private native void deinitJni(long nativeHandle); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java new file mode 100644 index 00000000..2bc20d8c --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java @@ -0,0 +1,257 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.text.nlclassifier; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.util.List; +import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; + +/** + * Classifier API for natural language classification tasks, categorizes string into different + * classes. + * + * <p>The API expects a TFLite model with the following input/output tensor: + * + * <ul> + * <li>Input tensor (kTfLiteString) + * <ul> + * <li>input of the model, accepts a string. + * </ul> + * <li>Output score tensor + * (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/kTfLiteFloat64/kTfLiteBool) + * <ul> + * <li>output scores for each class, if type is one of the Int types, dequantize it, if it + * is Bool type, convert the values to 0.0 and 1.0 respectively. + * <li>can have an optional associated file in metadata for labels, the file should be a + * plain text file with one label per line, the number of labels should match the number + * of categories the model outputs. Output label tensor: optional (kTfLiteString) - + * output classname for each class, should be of the same length with scores. If this + * tensor is not present, the API uses score indices as classnames. - will be ignored if + * output score tensor already has an associated label file. + * </ul> + * <li>Optional Output label tensor (kTfLiteString/kTfLiteInt32) + * <ul> + * <li>output classname for each class, should be of the same length with scores. 
If this + * tensor is not present, the API uses score indices as classnames. + * <li>will be ignored if output score tensor already has an associated labe file. + * </ul> + * </ul> + * + * <p>By default the API tries to find the input/output tensors with default configurations in + * {@link NLClassifierOptions}, with tensor name prioritized over tensor index. The option is + * configurable for different TFLite models. + */ +public class NLClassifier extends BaseTaskApi { + + /** Options to identify input and output tensors of the model. */ + @AutoValue + @UsedByReflection("nl_classifier_jni.cc") + public abstract static class NLClassifierOptions { + private static final int DEFAULT_INPUT_TENSOR_INDEX = 0; + private static final int DEFAULT_OUTPUT_SCORE_TENSOR_INDEX = 0; + // By default there is no output label tensor. The label file can be attached + // to the output score tensor metadata. + private static final int DEFAULT_OUTPUT_LABEL_TENSOR_INDEX = -1; + private static final String DEFAULT_INPUT_TENSOR_NAME = "INPUT"; + private static final String DEFAULT_OUTPUT_SCORE_TENSOR_NAME = "OUTPUT_SCORE"; + private static final String DEFAULT_OUTPUT_LABEL_TENSOR_NAME = "OUTPUT_LABEL"; + + @UsedByReflection("nl_classifier_jni.cc") + abstract int inputTensorIndex(); + + @UsedByReflection("nl_classifier_jni.cc") + abstract int outputScoreTensorIndex(); + + @UsedByReflection("nl_classifier_jni.cc") + abstract int outputLabelTensorIndex(); + + @UsedByReflection("nl_classifier_jni.cc") + abstract String inputTensorName(); + + @UsedByReflection("nl_classifier_jni.cc") + abstract String outputScoreTensorName(); + + @UsedByReflection("nl_classifier_jni.cc") + abstract String outputLabelTensorName(); + + public static Builder builder() { + return new AutoValue_NLClassifier_NLClassifierOptions.Builder() + .setInputTensorIndex(DEFAULT_INPUT_TENSOR_INDEX) + .setOutputScoreTensorIndex(DEFAULT_OUTPUT_SCORE_TENSOR_INDEX) + .setOutputLabelTensorIndex(DEFAULT_OUTPUT_LABEL_TENSOR_INDEX) + 
.setInputTensorName(DEFAULT_INPUT_TENSOR_NAME) + .setOutputScoreTensorName(DEFAULT_OUTPUT_SCORE_TENSOR_NAME) + .setOutputLabelTensorName(DEFAULT_OUTPUT_LABEL_TENSOR_NAME); + } + + /** Builder for {@link NLClassifierOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setInputTensorIndex(int value); + + public abstract Builder setOutputScoreTensorIndex(int value); + + public abstract Builder setOutputLabelTensorIndex(int value); + + public abstract Builder setInputTensorName(String value); + + public abstract Builder setOutputScoreTensorName(String value); + + public abstract Builder setOutputLabelTensorName(String value); + + public abstract NLClassifierOptions build(); + } + } + + private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; + + /** + * Constructor to initialize the JNI with a pointer from C++. + * + * @param nativeHandle a pointer referencing memory allocated in C++. + */ + protected NLClassifier(long nativeHandle) { + super(nativeHandle); + } + + /** + * Create {@link NLClassifier} from default {@link NLClassifierOptions}. + * + * @param context Android context. + * @param pathToModel Path to the classification model relative to asset dir. + * @return {@link NLClassifier} instance. + * @throws IOException If model file fails to load. + */ + public static NLClassifier createFromFile(Context context, String pathToModel) + throws IOException { + return createFromFileAndOptions(context, pathToModel, NLClassifierOptions.builder().build()); + } + + /** + * Create {@link NLClassifier} from default {@link NLClassifierOptions}. + * + * @param modelFile The classification model {@link File} instance. + * @return {@link NLClassifier} instance. + * @throws IOException If model file fails to load. 
+ */ + public static NLClassifier createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, NLClassifierOptions.builder().build()); + } + + /** + * Create {@link NLClassifier} from {@link NLClassifierOptions}. + * + * @param context Android context + * @param pathToModel Path to the classification model relative to asset dir. + * @param options Configurations for the model. + * @return {@link NLClassifier} instance. + * @throws IOException If model file fails to load. + */ + public static NLClassifier createFromFileAndOptions( + Context context, String pathToModel, NLClassifierOptions options) throws IOException { + return createFromBufferAndOptions(TaskJniUtils.loadMappedFile(context, pathToModel), options); + } + + /** + * Create {@link NLClassifier} from {@link NLClassifierOptions}. + * + * @param modelFile The classification model {@link File} instance. + * @param options Configurations for the model. + * @return {@link NLClassifier} instance. + * @throws IOException If model file fails to load. + */ + public static NLClassifier createFromFileAndOptions( + File modelFile, final NLClassifierOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new NLClassifier( + TaskJniUtils.createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithFileDescriptor(options, descriptor.getFd()); + } + }, + NL_CLASSIFIER_NATIVE_LIBNAME)); + } + } + + /** + * Create {@link NLClassifier} with a model {@link ByteBuffer} and {@link NLClassifierOptions}. 
+ * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @param options Configurations for the model + * @return {@link NLClassifier} instance + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static NLClassifier createFromBufferAndOptions( + final ByteBuffer modelBuffer, final NLClassifierOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new NLClassifier( + TaskJniUtils.createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(options, modelBuffer); + } + }, + NL_CLASSIFIER_NATIVE_LIBNAME)); + } + + /** + * Perform classification on a string input, returns classified {@link Category}s. + * + * @param text input text to the model. + * @return A list of Category results. + */ + public List<Category> classify(String text) { + return classifyNative(getNativeHandle(), text); + } + + private static native long initJniWithByteBuffer( + NLClassifierOptions options, ByteBuffer modelBuffer); + + private static native long initJniWithFileDescriptor(NLClassifierOptions options, int fd); + + private static native List<Category> classifyNative(long nativeHandle, String text); + + @Override + protected void deinit(long nativeHandle) { + deinitJni(nativeHandle); + } + + /** + * Native implementation to release memory pointed by the pointer. 
+ * + * @param nativeHandle pointer to memory allocated + */ + private native void deinitJni(long nativeHandle); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD new file mode 100644 index 00000000..3dad1422 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD @@ -0,0 +1,33 @@ +load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +filegroup( + name = "bert_question_answerer_src", + srcs = glob(["**/*.java"]), +) + +android_library( + name = "bert_question_answerer", + srcs = glob(["*.java"]), + javacopts = JAVACOPTS, + deps = [ + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "//tensorflow_lite_support/java/src/native/task/text/qa:bert_question_answerer_native", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + ], +) + +# AAR target for OSS release. 
+# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa:bert-question-answerer +aar_with_jni( + name = "bert-question-answerer", + android_library = ":bert_question_answerer", +) diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java new file mode 100644 index 00000000..76f562ef --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java @@ -0,0 +1,195 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.text.qa; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider; + +/** Task API for BertQA models. 
*/ +public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer { + private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni"; + + private BertQuestionAnswerer(long nativeHandle) { + super(nativeHandle); + } + + /** + * Generic API to create the QuestionAnswerer for bert models with metadata populated. The API + * expects a Bert based TFLite model with metadata containing the following information: + * + * <ul> + * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be + * used for a <a + * href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a> + * model, Sentencepiece Tokenizer Tokenizer can be used for an <a + * href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a> + * model. + * <li>3 input tensors with names "ids", "mask" and "segment_ids". + * <li>2 output tensors with names "end_logits" and "start_logits". + * </ul> + * + * @param context android context + * @param pathToModel file path to the model with metadata. Note: The model should not be + * compressed + * @return {@link BertQuestionAnswerer} instance + * @throws IOException If model file fails to load. + */ + public static BertQuestionAnswerer createFromFile(Context context, String pathToModel) + throws IOException { + return new BertQuestionAnswerer( + TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( + context, + new MultipleBuffersHandleProvider() { + @Override + public long createHandle(ByteBuffer... buffers) { + return BertQuestionAnswerer.initJniWithModelWithMetadataByteBuffers(buffers); + } + }, + BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, + pathToModel)); + } + + /** + * Generic API to create the QuestionAnswerer for bert models with metadata populated. 
The API + * expects a Bert based TFLite model with metadata containing the following information: + * + * <ul> + * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be + * used for a <a + * href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a> + * model, Sentencepiece Tokenizer Tokenizer can be used for an <a + * href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a> + * model. + * <li>3 input tensors with names "ids", "mask" and "segment_ids". + * <li>2 output tensors with names "end_logits" and "start_logits". + * </ul> + * + * @param modelFile {@link File} object of the model + * @return {@link BertQuestionAnswerer} instance + * @throws IOException If model file fails to load. + */ + public static BertQuestionAnswerer createFromFile(File modelFile) + throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new BertQuestionAnswerer( + TaskJniUtils.createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithFileDescriptor(descriptor.getFd()); + } + }, + BERT_QUESTION_ANSWERER_NATIVE_LIBNAME)); + } + } + + /** + * Creates the API instance with a bert model and vocabulary file. + * + * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1 + * + * @param context android context + * @param pathToModel file path to the bert model. Note: The model should not be compressed + * @param pathToVocab file path to the vocabulary file. Note: The file should not be compressed + * @return {@link BertQuestionAnswerer} instance + * @throws IOException If model file fails to load. 
+ */ + public static BertQuestionAnswerer createBertQuestionAnswererFromFile( + Context context, String pathToModel, String pathToVocab) throws IOException { + return new BertQuestionAnswerer( + TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( + context, + new MultipleBuffersHandleProvider() { + @Override + public long createHandle(ByteBuffer... buffers) { + return BertQuestionAnswerer.initJniWithBertByteBuffers(buffers); + } + }, + BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, + pathToModel, + pathToVocab)); + } + + /** + * Creates the API instance with an albert model and sentence piece model file. + * + * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1 + * + * @param context android context + * @param pathToModel file path to the albert model. Note: The model should not be compressed + * @param pathToSentencePieceModel file path to the sentence piece model file. Note: The model + * should not be compressed + * @return {@link BertQuestionAnswerer} instance + * @throws IOException If model file fails to load. + */ + public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile( + Context context, String pathToModel, String pathToSentencePieceModel) throws IOException { + return new BertQuestionAnswerer( + TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary( + context, + new MultipleBuffersHandleProvider() { + @Override + public long createHandle(ByteBuffer... buffers) { + return BertQuestionAnswerer.initJniWithAlbertByteBuffers(buffers); + } + }, + BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, + pathToModel, + pathToSentencePieceModel)); + } + + @Override + public List<QaAnswer> answer(String context, String question) { + checkNotClosed(); + return answerNative(getNativeHandle(), context, question); + } + + // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer. + private static native long initJniWithBertByteBuffers(ByteBuffer... 
modelBuffers); + + // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file + // buffer. + private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers); + + // modelBuffers[0] is tflite model file buffer with metadata to specify which tokenizer to use. + private static native long initJniWithModelWithMetadataByteBuffers(ByteBuffer... modelBuffers); + + private static native long initJniWithFileDescriptor(int fd); + + private static native List<QaAnswer> answerNative( + long nativeHandle, String context, String question); + + @Override + protected void deinit(long nativeHandle) { + deinitJni(nativeHandle); + } + + /** + * Native implementation to release memory pointed by the pointer. + * + * @param nativeHandle pointer to memory allocated + */ + private native void deinitJni(long nativeHandle); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java new file mode 100644 index 00000000..4259a697 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QaAnswer.java @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.text.qa; + +import org.tensorflow.lite.annotations.UsedByReflection; + +/** + * Answers to {@link QuestionAnswerer}. Contains information about the answer and its relative + * position information to the context. + */ +public class QaAnswer { + public Pos pos; + public String text; + + @UsedByReflection("bert_question_answerer_jni.cc") + public QaAnswer(String text, Pos pos) { + this.text = text; + this.pos = pos; + } + + public QaAnswer(String text, int start, int end, float logit) { + this(text, new Pos(start, end, logit)); + } + + /** + * Position information of the answer relative to context. It is sortable in descending order + * based on logit. + */ + public static class Pos implements Comparable<Pos> { + public int start; + public int end; + public float logit; + + public Pos(int start, int end, float logit) { + this.start = start; + this.end = end; + this.logit = logit; + } + + @Override + public int compareTo(Pos other) { + return Float.compare(other.logit, this.logit); + } + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java new file mode 100644 index 00000000..8df6d379 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/QuestionAnswerer.java @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.text.qa; + +import java.util.List; + +/** API to answer questions based on context. */ +public interface QuestionAnswerer { + + /** + * Answers question based on context, and returns a list of possible {@link QaAnswer}s. Could be + * empty if no answer was found from the given context. + * + * @param context context the question bases on + * @param question question to ask + * @return a list of possible answers in {@link QaAnswer} + */ + List<QaAnswer> answer(String context, String question); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml new file mode 100644 index 00000000..e77a0734 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/AndroidManifest.xml @@ -0,0 +1,5 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.task.vision"> + <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> +</manifest> diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD new file mode 100644 index 00000000..661a7669 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/BUILD @@ -0,0 +1,41 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "AndroidManifest.xml", +]) + +android_library( + name = 
"task_library_vision", + srcs = [ + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier:image_classifier_src", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector:object_detector_src", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter:image_segmenter_src", + ], + # TODO(b/163039980): Use JAVACOPTS in TF. "-Xep:RemoveUnusedImports:ERROR" wierdly break the build. + javacopts = ["-source 7 -target 7"], + manifest = "AndroidManifest.xml", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "//tensorflow_lite_support/java/src/native/task/vision:task_vision_native", + "@com_google_auto_value", + "@maven//:androidx_annotation_annotation", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + ], +) + +# AAR target for OSS release. +# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision:task-library-vision +aar_with_jni( + name = "task-library-vision", + android_library = ":task_library_vision", +) diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml new file mode 100644 index 00000000..ce07182e --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/AndroidManifest.xml @@ -0,0 +1,5 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.task.vision.classifier"> + <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> +</manifest> diff --git 
a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD new file mode 100644 index 00000000..c6a70a08 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/BUILD @@ -0,0 +1,40 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "AndroidManifest.xml", +]) + +filegroup( + name = "image_classifier_src", + srcs = glob(["**/*.java"]), +) + +android_library( + name = "image_classifier", + srcs = glob(["*.java"]), + javacopts = JAVACOPTS, + manifest = "AndroidManifest.xml", + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "//tensorflow_lite_support/java/src/native/task/vision/classifier:image_classifier_native", + "@com_google_auto_value", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + ], +) + +# AAR target for OSS release. 
+# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier:image-classifier +aar_with_jni( + name = "image-classifier", + android_library = ":image_classifier", +) diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java new file mode 100644 index 00000000..d33f0fbb --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/Classifications.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.vision.classifier; + +import com.google.auto.value.AutoValue; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.label.Category; + +/** + * The classification results of one head in a multihead (a.k.a. multi-output) {@link + * ImageClassifier}. A multihead {@link ImageClassifier} can perform classification for multiple + * purposes, such as a fine grained classifier to describe apparel items (e.g. color, material, + * type, etc.). 
+ */ +@AutoValue +@UsedByReflection("image_classifier_jni.cc") +public abstract class Classifications { + + @UsedByReflection("image_classifier_jni.cc") + static Classifications create(List<Category> categories, int headIndex) { + return new AutoValue_Classifications( + Collections.unmodifiableList(new ArrayList<Category>(categories)), headIndex); + } + + // Same reason for not using ImmutableList as stated in + // {@link ImageClassifier#ImageClassifierOptions#labelAllowList}. + public abstract List<Category> getCategories(); + + public abstract int getHeadIndex(); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java new file mode 100644 index 00000000..46f6754e --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java @@ -0,0 +1,453 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.vision.classifier; + +import android.content.Context; +import android.graphics.Rect; +import android.os.ParcelFileDescriptor; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; + +/** + * Performs classification on images. + * + * <p>The API expects a TFLite model with optional, but strongly recommended, <a + * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>. + * + * <p>The API supports models with one image input tensor and one classification output tensor. To + * be more specific, here are the requirements. + * + * <ul> + * <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + * <ul> + * <li>image input of size {@code [batch x height x width x channels]}. + * <li>batch inference is not supported ({@code batch} is required to be 1). + * <li>only RGB inputs are supported ({@code channels} is required to be 3). + * <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached + * to the metadata for input normalization. 
+ * </ul> + * <li>Output score tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + * <ul> + * <li>with {@code N} classes of either 2 or 4 dimensions, such as {@code [1 x N]} or {@code + * [1 x 1 x 1 x N]} + * <li>the label file is required to be packed to the metadata. See the <a + * href="https://www.tensorflow.org/lite/convert/metadata#label_output">example of + * creating metadata for an image classifier</a>. If no label files are packed, it will + * use index as label in the result. + * </ul> + * </ul> + * + * <p>An example of such model can be found on <a + * href="https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1">TensorFlow + * Hub.</a>. + */ +public final class ImageClassifier extends BaseTaskApi { + + private static final String IMAGE_CLASSIFIER_NATIVE_LIB = "task_vision_jni"; + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; + + /** + * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. + * + * @param modelPath path of the classification model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native + * code + */ + public static ImageClassifier createFromFile(Context context, String modelPath) + throws IOException { + return createFromFileAndOptions(context, modelPath, ImageClassifierOptions.builder().build()); + } + + /** + * Creates an {@link ImageClassifier} instance from the default {@link ImageClassifierOptions}. 
+ * + * @param modelFile the classification model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native + * code + */ + public static ImageClassifier createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, ImageClassifierOptions.builder().build()); + } + + /** + * Creates an {@link ImageClassifier} instance with a model buffer and the default {@link + * ImageClassifierOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native + * code + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static ImageClassifier createFromBuffer(final ByteBuffer modelBuffer) { + return createFromBufferAndOptions(modelBuffer, ImageClassifierOptions.builder().build()); + } + + /** + * Creates an {@link ImageClassifier} instance from {@link ImageClassifierOptions}. 
+ * + * @param modelPath path of the classification model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native + * code + */ + public static ImageClassifier createFromFileAndOptions( + Context context, String modelPath, ImageClassifierOptions options) throws IOException { + return new ImageClassifier( + TaskJniUtils.createHandleFromFdAndOptions( + context, + new FdAndOptionsHandleProvider<ImageClassifierOptions>() { + @Override + public long createHandle( + int fileDescriptor, + long fileDescriptorLength, + long fileDescriptorOffset, + ImageClassifierOptions options) { + return initJniWithModelFdAndOptions( + fileDescriptor, fileDescriptorLength, fileDescriptorOffset, options); + } + }, + IMAGE_CLASSIFIER_NATIVE_LIB, + modelPath, + options)); + } + + /** + * Creates an {@link ImageClassifier} instance. + * + * @param modelFile the classification model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native + * code + */ + public static ImageClassifier createFromFileAndOptions( + File modelFile, final ImageClassifierOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new ImageClassifier( + TaskJniUtils.createHandleFromLibrary( + new TaskJniUtils.EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithModelFdAndOptions( + descriptor.getFd(), + /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH, + /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET, + options); + } + }, + IMAGE_CLASSIFIER_NATIVE_LIB)); + } + } + + /** + * Creates an {@link ImageClassifier} instance with a model buffer and {@link + * ImageClassifierOptions}. 
+ * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws AssertionError if error occurs when creating {@link ImageClassifier} from the native + * code + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static ImageClassifier createFromBufferAndOptions( + final ByteBuffer modelBuffer, final ImageClassifierOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new ImageClassifier( + TaskJniUtils.createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer, options); + } + }, + IMAGE_CLASSIFIER_NATIVE_LIB)); + } + + /** + * Constructor to initialize the JNI with a pointer from C++. + * + * @param nativeHandle a pointer referencing memory allocated in C++ + */ + private ImageClassifier(long nativeHandle) { + super(nativeHandle); + } + + /** Options for setting up an ImageClassifier. */ + @UsedByReflection("image_classifier_jni.cc") + public static class ImageClassifierOptions { + // Not using AutoValue for this class because scoreThreshold cannot have default value + // (otherwise, the default value would override the one in the model metadata) and `Optional` is + // not an option here, because + // 1. java.util.Optional require Java 8 while we need to support Java 7. + // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the + // comments for labelAllowList. 
+ private final String displayNamesLocale; + private final int maxResults; + private final float scoreThreshold; + private final boolean isScoreThresholdSet; + // As an open source project, we've been trying avoiding depending on common java libraries, + // such as Guava, because it may introduce conflicts with clients who also happen to use those + // libraries. Therefore, instead of using ImmutableList here, we convert the List into + // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less + // vulnerable. + private final List<String> labelAllowList; + private final List<String> labelDenyList; + private final int numThreads; + + public static Builder builder() { + return new Builder(); + } + + /** A builder that helps to configure an instance of ImageClassifierOptions. */ + public static class Builder { + private String displayNamesLocale = "en"; + private int maxResults = -1; + private float scoreThreshold; + private boolean isScoreThresholdSet = false; + private List<String> labelAllowList = new ArrayList<>(); + private List<String> labelDenyList = new ArrayList<>(); + private int numThreads = -1; + + private Builder() {} + + /** + * Sets the locale to use for display names specified through the TFLite Model Metadata, if + * any. + * + * <p>Defaults to English({@code "en"}). See the <a + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite + * Metadata schema file.</a> for the accepted pattern of locale. + */ + public Builder setDisplayNamesLocale(String displayNamesLocale) { + this.displayNamesLocale = displayNamesLocale; + return this; + } + + /** + * Sets the maximum number of top scored results to return. + * + * <p>If < 0, all results will be returned. If 0, an invalid argument error is returned. + * Defaults to -1. + * + * @throws IllegalArgumentException if maxResults is 0. 
+ */ + public Builder setMaxResults(int maxResults) { + if (maxResults == 0) { + throw new IllegalArgumentException("maxResults cannot be 0."); + } + this.maxResults = maxResults; + return this; + } + + /** + * Sets the score threshold in [0,1). + * + * <p>It overrides the one provided in the model metadata (if any). Results below this value + * are rejected. + */ + public Builder setScoreThreshold(float scoreThreshold) { + this.scoreThreshold = scoreThreshold; + isScoreThresholdSet = true; + return this; + } + + /** + * Sets the optional allowlist of labels. + * + * <p>If non-empty, classifications whose label is not in this set will be filtered out. + * Duplicate or unknown labels are ignored. Mutually exclusive with labelDenyList. + */ + public Builder setLabelAllowList(List<String> labelAllowList) { + this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); + return this; + } + + /** + * Sets the optional denylist of labels. + * + * <p>If non-empty, classifications whose label is in this set will be filtered out. Duplicate + * or unknown labels are ignored. Mutually exclusive with labelAllowList. + */ + public Builder setLabelDenyList(List<String> labelDenyList) { + this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); + return this; + } + + /** + * Sets the number of threads to be used for TFLite ops that support multi-threading when + * running inference with CPU. Defaults to -1. + * + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the + * effect to let TFLite runtime set the value. 
+ */ + public Builder setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + + public ImageClassifierOptions build() { + return new ImageClassifierOptions(this); + } + } + + @UsedByReflection("image_classifier_jni.cc") + public String getDisplayNamesLocale() { + return displayNamesLocale; + } + + @UsedByReflection("image_classifier_jni.cc") + public int getMaxResults() { + return maxResults; + } + + @UsedByReflection("image_classifier_jni.cc") + public float getScoreThreshold() { + return scoreThreshold; + } + + @UsedByReflection("image_classifier_jni.cc") + public boolean getIsScoreThresholdSet() { + return isScoreThresholdSet; + } + + @UsedByReflection("image_classifier_jni.cc") + public List<String> getLabelAllowList() { + return new ArrayList<>(labelAllowList); + } + + @UsedByReflection("image_classifier_jni.cc") + public List<String> getLabelDenyList() { + return new ArrayList<>(labelDenyList); + } + + @UsedByReflection("image_classifier_jni.cc") + public int getNumThreads() { + return numThreads; + } + + private ImageClassifierOptions(Builder builder) { + displayNamesLocale = builder.displayNamesLocale; + maxResults = builder.maxResults; + scoreThreshold = builder.scoreThreshold; + isScoreThresholdSet = builder.isScoreThresholdSet; + labelAllowList = builder.labelAllowList; + labelDenyList = builder.labelDenyList; + numThreads = builder.numThreads; + } + } + + /** + * Performs actual classification on the provided image. + * + * @param image a {@link TensorImage} object that represents an RGB image + * @throws AssertionError if error occurs when classifying the image from the native code + */ + public List<Classifications> classify(TensorImage image) { + return classify(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs actual classification on the provided image with {@link ImageProcessingOptions}. 
+ * + * <p>{@link ImageClassifier} supports the following options: + * + * <ul> + * <li>Region of interest (ROI) (through {@link ImageProcessingOptions#Builder#setRoi}). It + * defaults to the entire image. + * <li>image rotation (through {@link ImageProcessingOptions#Builder#setOrientation}). It + * defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}. + * </ul> + * + * @param image a {@link TensorImage} object that represents an RGB image + * @throws AssertionError if error occurs when classifying the image from the native code + */ + public List<Classifications> classify(TensorImage image, ImageProcessingOptions options) { + checkNotClosed(); + + // image_classifier_jni.cc expects an uint8 image. Convert image of other types into uint8. + TensorImage imageUint8 = + image.getDataType() == DataType.UINT8 + ? image + : TensorImage.createFrom(image, DataType.UINT8); + + Rect roi = + options.getRoi().isEmpty() + ? new Rect(0, 0, imageUint8.getWidth(), imageUint8.getHeight()) + : options.getRoi(); + + return classifyNative( + getNativeHandle(), + imageUint8.getBuffer(), + imageUint8.getWidth(), + imageUint8.getHeight(), + new int[] {roi.left, roi.top, roi.width(), roi.height()}, + options.getOrientation().getValue()); + } + + private static native long initJniWithModelFdAndOptions( + int fileDescriptor, + long fileDescriptorLength, + long fileDescriptorOffset, + ImageClassifierOptions options); + + private static native long initJniWithByteBuffer( + ByteBuffer modelBuffer, ImageClassifierOptions options); + + /** + * The native method to classify an image with the ROI and orientation. 
+ * + * @param roi the ROI of the input image, an array representing the bounding box as {left, top, + * width, height} + * @param orientation the integer value corresponding to {@link + * ImageProcessingOptions#Orientation} + */ + private static native List<Classifications> classifyNative( + long nativeHandle, ByteBuffer image, int width, int height, int[] roi, int orientation); + + @Override + protected void deinit(long nativeHandle) { + deinitJni(nativeHandle); + } + + /** + * Native implementation to release memory pointed by the pointer. + * + * @param nativeHandle pointer to memory allocated + */ + private native void deinitJni(long nativeHandle); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml new file mode 100644 index 00000000..5fefccd0 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/AndroidManifest.xml @@ -0,0 +1,5 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.task.vision.detector"> + <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> +</manifest> diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD new file mode 100644 index 00000000..d0d541ab --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/BUILD @@ -0,0 +1,40 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + 
+exports_files([ + "AndroidManifest.xml", +]) + +filegroup( + name = "object_detector_src", + srcs = glob(["**/*.java"]), +) + +android_library( + name = "object_detector", + srcs = glob(["*.java"]), + javacopts = JAVACOPTS, + manifest = "AndroidManifest.xml", + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "//tensorflow_lite_support/java/src/native/task/vision/detector:object_detector_native", + "@com_google_auto_value", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + ], +) + +# AAR target for OSS release. +# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector:object-detector +aar_with_jni( + name = "object-detector", + android_library = ":object_detector", +) diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java new file mode 100644 index 00000000..007e032d --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/Detection.java @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.vision.detector; + +import android.graphics.RectF; +import com.google.auto.value.AutoValue; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.label.Category; + +/** Represents one detected object in the results of a {@link ObjectDetector}. */ +@AutoValue +@UsedByReflection("object_detection_jni.cc") +public abstract class Detection { + + @UsedByReflection("object_detection_jni.cc") + public static Detection create(RectF boundingBox, List<Category> categories) { + return new AutoValue_Detection( + new RectF(boundingBox), Collections.unmodifiableList(new ArrayList<Category>(categories))); + } + + public abstract RectF getBoundingBox(); + + // Same reason for not using ImmutableList as stated in + // {@link ObjectDetector#ObjectDetectorOptions#labelAllowList}. + public abstract List<Category> getCategories(); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java new file mode 100644 index 00000000..75bc9836 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java @@ -0,0 +1,452 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.vision.detector; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.annotations.UsedByReflection; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; + +/** + * Performs object detection on images. + * + * <p>The API expects a TFLite model with <a + * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>. + * + * <p>The API supports models with one image input tensor and four output tensors. To be more + * specific, here are the requirements. + * + * <ul> + * <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + * <ul> + * <li>image input of size {@code [batch x height x width x channels]}. + * <li>batch inference is not supported ({@code batch} is required to be 1). + * <li>only RGB inputs are supported ({@code channels} is required to be 3). + * <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached + * to the metadata for input normalization. 
+ * </ul> + * <li>Output tensors must be the 4 outputs of a {@code DetectionPostProcess} op, i.e: + * <ul> + * <li>Location tensor ({@code kTfLiteFloat32}): + * <ul> + * <li>tensor of size {@code [1 x num_results x 4]}, the inner array representing + * bounding boxes in the form [top, left, right, bottom]. + * <li>{@code BoundingBoxProperties} are required to be attached to the metadata and + * must specify {@code type=BOUNDARIES} and {@code coordinate_type=RATIO}. + * </ul> + * <li>Classes tensor ({@code kTfLiteFloat32}): + * <ul> + * <li>tensor of size {@code [1 x num_results]}, each value representing the integer + * index of a class. + * <li>if label maps are attached to the metadata as {@code TENSOR_VALUE_LABELS} + * associated files, they are used to convert the tensor values into labels. + * </ul> + * <li>scores tensor ({@code kTfLiteFloat32}): + * <ul> + * <li>tensor of size {@code [1 x num_results]}, each value representing the score of + * the detected object. + * </ul> + * <li>Number of detection tensor ({@code kTfLiteFloat32}): + * <ul> + * <li>integer num_results as a tensor of size {@code [1]}. + * </ul> + * </ul> + * </ul> + * + * <p>An example of such model can be found on <a + * href="https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1">TensorFlow + * Hub.</a>. + */ +public final class ObjectDetector extends BaseTaskApi { + + private static final String OBJECT_DETECTOR_NATIVE_LIB = "task_vision_jni"; + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; + + /** + * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}. 
+ * + * @param modelPath path to the detection model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native + * code + */ + public static ObjectDetector createFromFile(Context context, String modelPath) + throws IOException { + return createFromFileAndOptions(context, modelPath, ObjectDetectorOptions.builder().build()); + } + + /** + * Creates an {@link ObjectDetector} instance from the default {@link ObjectDetectorOptions}. + * + * @param modelFile the detection model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native + * code + */ + public static ObjectDetector createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, ObjectDetectorOptions.builder().build()); + } + + /** + * Creates an {@link ObjectDetector} instance with a model buffer and the default {@link + * ObjectDetectorOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native + * code + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static ObjectDetector createFromBuffer(final ByteBuffer modelBuffer) { + return createFromBufferAndOptions(modelBuffer, ObjectDetectorOptions.builder().build()); + } + + /** + * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}. 
+ * + * @param modelPath path to the detection model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native + * code + */ + public static ObjectDetector createFromFileAndOptions( + Context context, String modelPath, ObjectDetectorOptions options) throws IOException { + return new ObjectDetector( + TaskJniUtils.createHandleFromFdAndOptions( + context, + new FdAndOptionsHandleProvider<ObjectDetectorOptions>() { + @Override + public long createHandle( + int fileDescriptor, + long fileDescriptorLength, + long fileDescriptorOffset, + ObjectDetectorOptions options) { + return initJniWithModelFdAndOptions( + fileDescriptor, fileDescriptorLength, fileDescriptorOffset, options); + } + }, + OBJECT_DETECTOR_NATIVE_LIB, + modelPath, + options)); + } + + /** + * Creates an {@link ObjectDetector} instance from {@link ObjectDetectorOptions}. + * + * @param modelFile the detection model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native + * code + */ + public static ObjectDetector createFromFileAndOptions( + File modelFile, final ObjectDetectorOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return new ObjectDetector( + TaskJniUtils.createHandleFromLibrary( + new TaskJniUtils.EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithModelFdAndOptions( + descriptor.getFd(), + /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH, + /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET, + options); + } + }, + OBJECT_DETECTOR_NATIVE_LIB)); + } + } + + /** + * Creates an {@link ObjectDetector} instance with a model buffer and {@link + * ObjectDetectorOptions}. 
+ * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws AssertionError if error occurs when creating {@link ObjectDetector} from the native + * code + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static ObjectDetector createFromBufferAndOptions( + final ByteBuffer modelBuffer, final ObjectDetectorOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new ObjectDetector( + TaskJniUtils.createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer(modelBuffer, options); + } + }, + OBJECT_DETECTOR_NATIVE_LIB)); + } + + /** + * Constructor to initialize the JNI with a pointer from C++. + * + * @param nativeHandle a pointer referencing memory allocated in C++ + */ + private ObjectDetector(long nativeHandle) { + super(nativeHandle); + } + + /** Options for setting up an ObjectDetector. */ + @UsedByReflection("object_detector_jni.cc") + public static class ObjectDetectorOptions { + // Not using AutoValue for this class because scoreThreshold cannot have default value + // (otherwise, the default value would override the one in the model metadata) and `Optional` is + // not an option here, because + // 1. java.util.Optional require Java 8 while we need to support Java 7. + // 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the + // comments for labelAllowList. 
+ private final String displayNamesLocale; + private final int maxResults; + private final float scoreThreshold; + private final boolean isScoreThresholdSet; + // As an open source project, we've been trying avoiding depending on common java libraries, + // such as Guava, because it may introduce conflicts with clients who also happen to use those + // libraries. Therefore, instead of using ImmutableList here, we convert the List into + // unmodifiableList in setLabelAllowList() and setLabelDenyList() to make it less + // vulnerable. + private final List<String> labelAllowList; + private final List<String> labelDenyList; + private final int numThreads; + + public static Builder builder() { + return new Builder(); + } + + /** A builder that helps to configure an instance of ObjectDetectorOptions. */ + public static class Builder { + private String displayNamesLocale = "en"; + private int maxResults = -1; + private float scoreThreshold; + private boolean isScoreThresholdSet = false; + private List<String> labelAllowList = new ArrayList<>(); + private List<String> labelDenyList = new ArrayList<>(); + private int numThreads = -1; + + private Builder() {} + + /** + * Sets the locale to use for display names specified through the TFLite Model Metadata, if + * any. + * + * <p>Defaults to English({@code "en"}). See the <a + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite + * Metadata schema file.</a> for the accepted pattern of locale. + */ + public Builder setDisplayNamesLocale(String displayNamesLocale) { + this.displayNamesLocale = displayNamesLocale; + return this; + } + + /** + * Sets the maximum number of top-scored detection results to return. + * + * <p>If < 0, all available results will be returned. If 0, an invalid argument error is + * returned. 
Note that models may intrinsically be limited to returning a maximum number of + * results N: if the provided value here is above N, only N results will be returned. Defaults + * to -1. + * + * @throws IllegalArgumentException if maxResults is 0. + */ + public Builder setMaxResults(int maxResults) { + if (maxResults == 0) { + throw new IllegalArgumentException("maxResults cannot be 0."); + } + this.maxResults = maxResults; + return this; + } + + /** + * Sets the score threshold that overrides the one provided in the model metadata (if any). + * Results below this value are rejected. + */ + public Builder setScoreThreshold(float scoreThreshold) { + this.scoreThreshold = scoreThreshold; + this.isScoreThresholdSet = true; + return this; + } + + /** + * Sets the optional allow list of labels. + * + * <p>If non-empty, detection results whose label is not in this set will be filtered out. + * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelDenyList}. It + * will cause {@link AssertionError} when calling {@link #createFromFileAndOptions}, if both + * {@code labelDenyList} and {@code labelAllowList} are set. + */ + public Builder setLabelAllowList(List<String> labelAllowList) { + this.labelAllowList = Collections.unmodifiableList(new ArrayList<>(labelAllowList)); + return this; + } + + /** + * Sets the optional deny list of labels. + * + * <p>If non-empty, detection results whose label is in this set will be filtered out. + * Duplicate or unknown labels are ignored. Mutually exclusive with {@code labelAllowList}. It + * will cause {@link AssertionError} when calling {@link #createFromFileAndOptions}, if both + * {@code labelDenyList} and {@code labelAllowList} are set. 
+ */ + public Builder setLabelDenyList(List<String> labelDenyList) { + this.labelDenyList = Collections.unmodifiableList(new ArrayList<>(labelDenyList)); + return this; + } + + /** + * Sets the number of threads to be used for TFLite ops that support multi-threading when + * running inference with CPU. Defaults to -1. + * + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the + * effect to let TFLite runtime set the value. + */ + public Builder setNumThreads(int numThreads) { + this.numThreads = numThreads; + return this; + } + + public ObjectDetectorOptions build() { + return new ObjectDetectorOptions(this); + } + } + + @UsedByReflection("object_detector_jni.cc") + public String getDisplayNamesLocale() { + return displayNamesLocale; + } + + @UsedByReflection("object_detector_jni.cc") + public int getMaxResults() { + return maxResults; + } + + @UsedByReflection("object_detector_jni.cc") + public float getScoreThreshold() { + return scoreThreshold; + } + + @UsedByReflection("object_detector_jni.cc") + public boolean getIsScoreThresholdSet() { + return isScoreThresholdSet; + } + + @UsedByReflection("object_detector_jni.cc") + public List<String> getLabelAllowList() { + return new ArrayList<>(labelAllowList); + } + + @UsedByReflection("object_detector_jni.cc") + public List<String> getLabelDenyList() { + return new ArrayList<>(labelDenyList); + } + + @UsedByReflection("object_detector_jni.cc") + public int getNumThreads() { + return numThreads; + } + + private ObjectDetectorOptions(Builder builder) { + displayNamesLocale = builder.displayNamesLocale; + maxResults = builder.maxResults; + scoreThreshold = builder.scoreThreshold; + isScoreThresholdSet = builder.isScoreThresholdSet; + labelAllowList = builder.labelAllowList; + labelDenyList = builder.labelDenyList; + numThreads = builder.numThreads; + } + } + + /** + * Performs actual detection on the provided image. 
+ * + * @param image a {@link TensorImage} object that represents a RGB image + * @throws AssertionError if error occurs when processing the image from the native code + */ + public List<Detection> detect(TensorImage image) { + return detect(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs actual detection on the provided image. + * + * @param image a {@link TensorImage} object that represents a RGB image + * @param options {@link ObjectDetector} only supports image rotation (through {@link + * ImageProcessingOptions#Builder#setOrientation}) currently. The orientation of an image + * defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}. + * @throws AssertionError if error occurs when processing the image from the native code + */ + public List<Detection> detect(TensorImage image, ImageProcessingOptions options) { + checkNotClosed(); + + // object_detector_jni.cc expects an uint8 image. Convert image of other types into uint8. + TensorImage imageUint8 = + image.getDataType() == DataType.UINT8 + ? image + : TensorImage.createFrom(image, DataType.UINT8); + return detectNative( + getNativeHandle(), + imageUint8.getBuffer(), + imageUint8.getWidth(), + imageUint8.getHeight(), + options.getOrientation().getValue()); + } + + private static native long initJniWithModelFdAndOptions( + int fileDescriptor, + long fileDescriptorLength, + long fileDescriptorOffset, + ObjectDetectorOptions options); + + private static native long initJniWithByteBuffer( + ByteBuffer modelBuffer, ObjectDetectorOptions options); + + private static native List<Detection> detectNative( + long nativeHandle, ByteBuffer image, int width, int height, int orientation); + + @Override + protected void deinit(long nativeHandle) { + deinitJni(nativeHandle); + } + + /** + * Native implementation to release memory pointed by the pointer. 
+ * + * @param nativeHandle pointer to memory allocated + */ + private native void deinitJni(long nativeHandle); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml new file mode 100644 index 00000000..991d4816 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/AndroidManifest.xml @@ -0,0 +1,5 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android" + package="org.tensorflow.lite.task.vision.segmenter"> + <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> +</manifest> diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD new file mode 100644 index 00000000..506d5bfb --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/BUILD @@ -0,0 +1,41 @@ +load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "AndroidManifest.xml", +]) + +filegroup( + name = "image_segmenter_src", + srcs = glob(["**/*.java"]), +) + +android_library( + name = "image_segmenter", + srcs = glob(["*.java"]), + # TODO(b/163039980): Use JAVACOPTS in TF. "-Xep:RemoveUnusedImports:ERROR" wierdly break the build. 
+ javacopts = ["-source 7 -target 7"], + manifest = "AndroidManifest.xml", + deps = [ + "//tensorflow_lite_support/java:tensorflowlite_support_java", + "//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api", + "//tensorflow_lite_support/java/src/native/task/vision/segmenter:image_segmenter_native", + "@com_google_auto_value", + "@maven//:androidx_annotation_annotation", + "@org_tensorflow//tensorflow/lite/java:tensorflowlite_java", + ], +) + +# AAR target for OSS release. +# +# bazel build -c opt --config=monolithic --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter:image-segmenter +aar_with_jni( + name = "image-segmenter", + android_library = ":image_segmenter", +) diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java new file mode 100644 index 00000000..09416d08 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ColoredLabel.java @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.vision.segmenter; + +import android.graphics.Color; +import android.os.Build; +import androidx.annotation.RequiresApi; +import com.google.auto.value.AutoValue; +import org.tensorflow.lite.annotations.UsedByReflection; + +/** Represents a label associated with a color for display purposes. */ +@AutoValue +@UsedByReflection("image_segmentation_jni.cc") +public abstract class ColoredLabel { + + /** + * Creates a {@link ColoredLabel} object with an ARGB color int. + * + * @param label the label string, as provided in the label map packed in the TFLite Model + * Metadata. + * @param displayName the display name of label, as configured through {@link + * ImageSegmenter#ImageSegmenterOptions#Builder#setDisplayNamesLocale} + * @param argb the color components for the label in ARGB. See <a + * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android + * Color ints.</a> for more details. + */ + @UsedByReflection("image_segmentation_jni.cc") + public static ColoredLabel create(String label, String displayName, int argb) { + return new AutoValue_ColoredLabel(label, displayName, argb); + } + + /** + * Creates a {@link ColoredLabel} object with a {@link Color} instance. + * + * @param label the label string, as provided in the label map packed in the TFLite Model + * Metadata. + * @param displayName the display name of label, as configured through {@link + * ImageSegmenter#ImageSegmenterOptions#Builder#setDisplayNamesLocale} + * @param color the color components for the label. The Color instatnce is supported on Android + * API level 26 and above. For API level lower than 26, use {@link #create(String, String, + * int)}. See <a + * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android + * Color instances.</a> for more details. 
+ */ + @RequiresApi(Build.VERSION_CODES.O) + public static ColoredLabel create(String label, String displayName, Color color) { + return new AutoValue_ColoredLabel(label, displayName, color.toArgb()); + } + + public abstract String getlabel(); + + public abstract String getDisplayName(); + + /** + * Gets the ARGB int that represents the color. + * + * <p>See <a + * href="https://developer.android.com/reference/android/graphics/Color#color-ints">Android Color + * ints.</a> for more details. + */ + public abstract int getArgb(); + + /** + * Gets the {@link Color} instance of the underlying color. + * + * <p>The Color instatnce is supported on Android API level 26 and above. For API level lower than + * 26, use {@link #getArgb()}. See <a + * href="https://developer.android.com/reference/android/graphics/Color#color-instances">Android + * Color instances.</a> for more details. + */ + @RequiresApi(Build.VERSION_CODES.O) + public Color getColor() { + return Color.valueOf(getArgb()); + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java new file mode 100644 index 00000000..bd90790f --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java @@ -0,0 +1,377 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.vision.segmenter; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.tensorflow.lite.DataType; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.task.core.BaseTaskApi; +import org.tensorflow.lite.task.core.TaskJniUtils; +import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; +import org.tensorflow.lite.task.core.vision.ImageProcessingOptions; + +/** + * Performs segmentation on images. + * + * <p>The API expects a TFLite model with <a + * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>. + * + * <p>The API supports models with one image input tensor and one output tensor. To be more + * specific, here are the requirements. + * + * <ul> + * <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + * <ul> + * <li>image input of size {@code [batch x height x width x channels]}. + * <li>batch inference is not supported ({@code batch} is required to be 1). + * <li>only RGB inputs are supported ({@code channels} is required to be 3). + * <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached + * to the metadata for input normalization. 
+ * </ul> + * <li>Output image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + * <ul> + * <li>tensor of size {@code [batch x mask_height x mask_width x num_classes]}, where {@code + * batch} is required to be 1, {@code mask_width} and {@code mask_height} are the + * dimensions of the segmentation masks produced by the model, and {@code num_classes} + * is the number of classes supported by the model. + * <li>optional (but recommended) label map(s) can be attached as AssociatedFile-s with type + * TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if + * any) is used to fill the class name, i.e. {@link ColoredLabel#getClassName} of the + * results. The display name, i.e. {@link ColoredLabel#getDisplayName}, is filled from + * the AssociatedFile (if any) whose locale matches the `display_names_locale` field of + * the `ImageSegmenterOptions` used at creation time ("en" by default, i.e. English). If + * none of these are available, only the `index` field of the results will be filled. + * </ul> + * </ul> + * + * <p>An example of such model can be found on <a + * href="https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1">TensorFlow Hub.</a>. + */ +public final class ImageSegmenter extends BaseTaskApi { + + private static final String IMAGE_SEGMENTER_NATIVE_LIB = "task_vision_jni"; + private static final int OPTIONAL_FD_LENGTH = -1; + private static final int OPTIONAL_FD_OFFSET = -1; + + private final OutputType outputType; + + /** + * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}. 
+ * + * @param modelPath path of the segmentation model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native + * code + */ + public static ImageSegmenter createFromFile(Context context, String modelPath) + throws IOException { + return createFromFileAndOptions(context, modelPath, ImageSegmenterOptions.builder().build()); + } + + /** + * Creates an {@link ImageSegmenter} instance from the default {@link ImageSegmenterOptions}. + * + * @param modelFile the segmentation model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native + * code + */ + public static ImageSegmenter createFromFile(File modelFile) throws IOException { + return createFromFileAndOptions(modelFile, ImageSegmenterOptions.builder().build()); + } + + /** + * Creates an {@link ImageSegmenter} instance with a model buffer and the default {@link + * ImageSegmenterOptions}. + * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native + * code + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static ImageSegmenter createFromBuffer(final ByteBuffer modelBuffer) { + return createFromBufferAndOptions(modelBuffer, ImageSegmenterOptions.builder().build()); + } + + /** + * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}. 
+ * + * @param modelPath path of the segmentation model with metadata in the assets + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native + * code + */ + public static ImageSegmenter createFromFileAndOptions( + Context context, String modelPath, final ImageSegmenterOptions options) throws IOException { + try (AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(modelPath)) { + return createFromModelFdAndOptions( + /*fileDescriptor=*/ assetFileDescriptor.getParcelFileDescriptor().getFd(), + /*fileDescriptorLength=*/ assetFileDescriptor.getLength(), + /*fileDescriptorOffset=*/ assetFileDescriptor.getStartOffset(), + options); + } + } + + /** + * Creates an {@link ImageSegmenter} instance from {@link ImageSegmenterOptions}. + * + * @param modelFile the segmentation model {@link File} instance + * @throws IOException if an I/O error occurs when loading the tflite model + * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native + * code + */ + public static ImageSegmenter createFromFileAndOptions( + File modelFile, final ImageSegmenterOptions options) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + return createFromModelFdAndOptions( + /*fileDescriptor=*/ descriptor.getFd(), + /*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH, + /*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET, + options); + } + } + + /** + * Creates an {@link ImageSegmenter} instance with a model buffer and {@link + * ImageSegmenterOptions}. 
+ * + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the + * classification model + * @throws AssertionError if error occurs when creating {@link ImageSegmenter} from the native + * code + * @throws IllegalArgumentException if the model buffer is not a direct {@link ByteBuffer} or a + * {@link MappedByteBuffer} + */ + public static ImageSegmenter createFromBufferAndOptions( + final ByteBuffer modelBuffer, final ImageSegmenterOptions options) { + if (!(modelBuffer.isDirect() || modelBuffer instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return new ImageSegmenter( + TaskJniUtils.createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithByteBuffer( + modelBuffer, + options.getDisplayNamesLocale(), + options.getOutputType().getValue(), + options.getNumThreads()); + } + }, + IMAGE_SEGMENTER_NATIVE_LIB), + options.getOutputType()); + } + + /** + * Constructor to initialize the JNI with a pointer from C++. + * + * @param nativeHandle a pointer referencing memory allocated in C++ + */ + private ImageSegmenter(long nativeHandle, OutputType outputType) { + super(nativeHandle); + this.outputType = outputType; + } + + /** Options for setting up an {@link ImageSegmenter}. 
*/ + @AutoValue + public abstract static class ImageSegmenterOptions { + private static final String DEFAULT_DISPLAY_NAME_LOCALE = "en"; + private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK; + private static final int NUM_THREADS = -1; + + public abstract String getDisplayNamesLocale(); + + public abstract OutputType getOutputType(); + + public abstract int getNumThreads(); + + public static Builder builder() { + return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder() + .setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE) + .setOutputType(DEFAULT_OUTPUT_TYPE) + .setNumThreads(NUM_THREADS); + } + + /** Builder for {@link ImageSegmenterOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + + /** + * Sets the locale to use for display names specified through the TFLite Model Metadata, if + * any. + * + * <p>Defaults to English({@code "en"}). See the <a + * href="https://github.com/tensorflow/tflite-support/blob/3ce83f0cfe2c68fecf83e019f2acc354aaba471f/tensorflow_lite_support/metadata/metadata_schema.fbs#L147">TFLite + * Metadata schema file.</a> for the accepted pattern of locale. + */ + public abstract Builder setDisplayNamesLocale(String displayNamesLocale); + + public abstract Builder setOutputType(OutputType outputType); + + /** + * Sets the number of threads to be used for TFLite ops that support multi-threading when + * running inference with CPU. Defaults to -1. + * + * <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the + * effect to let TFLite runtime set the value. + */ + public abstract Builder setNumThreads(int numThreads); + + public abstract ImageSegmenterOptions build(); + } + } + + /** + * Performs actual segmentation on the provided image. + * + * @param image a {@link TensorImage} object that represents an RGB image + * @return results of performing image segmentation. 
Note that at the time, a single {@link + * Segmentation} element is expected to be returned. The result is stored in a {@link List} + * for later extension to e.g. instance segmentation models, which may return one segmentation + * per object. + * @throws AssertionError if error occurs when segmenting the image from the native code + */ + public List<Segmentation> segment(TensorImage image) { + return segment(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs actual segmentation on the provided image with {@link ImageProcessingOptions}. + * + * @param image a {@link TensorImage} object that represents an RGB image + * @param options {@link ImageSegmenter} only supports image rotation (through {@link + * ImageProcessingOptions#Builder#setOrientation}) currently. The orientation of an image + * defaults to {@link ImageProcessingOptions#Orientation#TOP_LEFT}. + * @return results of performing image segmentation. Note that at the time, a single {@link + * Segmentation} element is expected to be returned. The result is stored in a {@link List} + * for later extension to e.g. instance segmentation models, which may return one segmentation + * per object. + * @throws AssertionError if error occurs when segmenting the image from the native code + */ + public List<Segmentation> segment(TensorImage image, ImageProcessingOptions options) { + checkNotClosed(); + + // image_segmenter_jni.cc expects an uint8 image. Convert image of other types into uint8. + TensorImage imageUint8 = + image.getDataType() == DataType.UINT8 + ? 
image + : TensorImage.createFrom(image, DataType.UINT8); + List<byte[]> maskByteArrays = new ArrayList<>(); + List<ColoredLabel> coloredLabels = new ArrayList<>(); + int[] maskShape = new int[2]; + segmentNative( + getNativeHandle(), + imageUint8.getBuffer(), + imageUint8.getWidth(), + imageUint8.getHeight(), + maskByteArrays, + maskShape, + coloredLabels, + options.getOrientation().getValue()); + + List<ByteBuffer> maskByteBuffers = new ArrayList<>(); + for (byte[] bytes : maskByteArrays) { + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + // Change the byte order to little_endian, since the buffers were generated in jni. + byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + maskByteBuffers.add(byteBuffer); + } + + return Arrays.asList( + Segmentation.create( + outputType, + outputType.createMasksFromBuffer(maskByteBuffers, maskShape), + coloredLabels)); + } + + private static ImageSegmenter createFromModelFdAndOptions( + final int fileDescriptor, + final long fileDescriptorLength, + final long fileDescriptorOffset, + final ImageSegmenterOptions options) { + long nativeHandle = + TaskJniUtils.createHandleFromLibrary( + new EmptyHandleProvider() { + @Override + public long createHandle() { + return initJniWithModelFdAndOptions( + fileDescriptor, + fileDescriptorLength, + fileDescriptorOffset, + options.getDisplayNamesLocale(), + options.getOutputType().getValue(), + options.getNumThreads()); + } + }, + IMAGE_SEGMENTER_NATIVE_LIB); + return new ImageSegmenter(nativeHandle, options.getOutputType()); + } + + private static native long initJniWithModelFdAndOptions( + int fileDescriptor, + long fileDescriptorLength, + long fileDescriptorOffset, + String displayNamesLocale, + int outputType, + int numThreads); + + private static native long initJniWithByteBuffer( + ByteBuffer modelBuffer, String displayNamesLocale, int outputType, int numThreads); + + /** + * The native method to segment the image. 
+ * + * <p>{@code maskBuffers}, {@code maskShape}, {@code coloredLabels} will be updated in the native + * layer. + */ + private static native void segmentNative( + long nativeHandle, + ByteBuffer image, + int width, + int height, + List<byte[]> maskByteArrays, + int[] maskShape, + List<ColoredLabel> coloredLabels, + int orientation); + + @Override + protected void deinit(long nativeHandle) { + deinitJni(nativeHandle); + } + + /** + * Native implementation to release memory pointed by the pointer. + * + * @param nativeHandle pointer to memory allocated + */ + private native void deinitJni(long nativeHandle); +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java new file mode 100644 index 00000000..03d82c6d --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java @@ -0,0 +1,145 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.lite.task.vision.segmenter; + +import static org.tensorflow.lite.DataType.FLOAT32; +import static org.tensorflow.lite.DataType.UINT8; +import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument; +import static org.tensorflow.lite.support.image.ColorSpaceType.GRAYSCALE; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import org.tensorflow.lite.support.image.TensorImage; +import org.tensorflow.lite.support.tensorbuffer.TensorBuffer; + +/** + * Output mask type. This allows specifying the type of post-processing to perform on the raw model + * results. + */ +public enum OutputType { + + /** + * Gives a single output mask where each pixel represents the class which the pixel in the + * original image was predicted to belong to. + */ + CATEGORY_MASK(0) { + /** + * {@inheritDoc} + * + * @throws IllegalArgumentException if more than one {@link TensorImage} are provided, or if the + * color space of the {@link TensorImage} is not {@link ColorSpaceType#GRAYSCALE} + */ + @Override + void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) { + checkArgument( + masks.size() == 1, + "CATRGORY_MASK only allows one TensorImage in the list, providing " + masks.size()); + + TensorImage mask = masks.get(0); + checkArgument( + mask.getColorSpaceType() == GRAYSCALE, + "CATRGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing " + + mask.getColorSpaceType()); + } + + /** + * {@inheritDoc} + * + * @throws IllegalArgumentException if more than one {@link ByteBuffer} are provided in the list + */ + @Override + List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) { + checkArgument( + buffers.size() == 1, + "CATRGORY_MASK only allows one mask in the buffer list, providing " + buffers.size()); + + List<TensorImage> masks = new ArrayList<>(); + 
TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8); + tensorBuffer.loadBuffer(buffers.get(0), maskShape); + TensorImage tensorImage = new TensorImage(UINT8); + tensorImage.load(tensorBuffer, GRAYSCALE); + masks.add(tensorImage); + + return masks; + } + }, + + /** + * Gives a list of output masks where, for each mask, each pixel represents the prediction + * confidence, usually in the [0, 1] range. + */ + CONFIDENCE_MASK(1) { + /** + * {@inheritDoc} + * + * @throws IllegalArgumentException if more the size of the masks list does not match the size + * of the coloredlabels list, or if the color space type of the any mask is not {@link + * ColorSpaceType#GRAYSCALE} + */ + @Override + void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) { + checkArgument( + masks.size() == coloredLabels.size(), + String.format( + "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of" + + " coloredLabels (%d).", + masks.size(), coloredLabels.size())); + + for (TensorImage mask : masks) { + checkArgument( + mask.getColorSpaceType() == GRAYSCALE, + "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing " + + mask.getColorSpaceType()); + } + } + + @Override + List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) { + List<TensorImage> masks = new ArrayList<>(); + for (ByteBuffer buffer : buffers) { + TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32); + tensorBuffer.loadBuffer(buffer, maskShape); + TensorImage tensorImage = new TensorImage(FLOAT32); + tensorImage.load(tensorBuffer, GRAYSCALE); + masks.add(tensorImage); + } + return masks; + } + }; + + public int getValue() { + return value; + } + + /** + * Verifies that the given list of masks matches the list of colored labels. 
+ * + * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the + * output type + */ + abstract void assertMasksMatchColoredLabels( + List<TensorImage> masks, List<ColoredLabel> coloredLabels); + + /** Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}. */ + abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape); + + private final int value; + + private OutputType(int value) { + this.value = value; + } +} diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java new file mode 100644 index 00000000..018482c7 --- /dev/null +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/Segmentation.java @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.vision.segmenter; + +import com.google.auto.value.AutoValue; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.tensorflow.lite.support.image.TensorImage; + +/** Represents the segmentation result of an {@link ImageSegmenter}. */ +@AutoValue +public abstract class Segmentation { + + /** + * Creates a {@link Segmentation} object. 
+ * + * <p>{@link Segmentation} provides two types of outputs as indicated through {@link OutputType}: + * + * <p>{@link OutputType#CATEGORY_MASK}: the result contains a single category mask, which is a + * grayscale {@link TensorImage} with shape (height, width), in row major order. The value of each + * pixel in this mask represents the class to which the pixel in the mask belongs. The pixel + * values are in 1:1 corresponding with the colored labels, i.e. a pixel with value {@code i} is + * associated with {@code coloredLabels.get(i)}. + * + * <p>{@link OutputType#CONFIDENCE_MASK}: the result contains a list of confidence masks, which + * are in 1:1 correspondance with the colored labels, i.e. {@link masks.get(i)} is associated with + * {@code coloredLabels.get(i)}. Each confidence mask is a grayscale {@link TensorImage} with + * shape (height, width), in row major order. The value of each pixel in these masks represents + * the confidence score for this particular class. + * + * <p>IMPORTANT: segmentation masks are not direcly suited for display, in particular:<br> + * \* they are relative to the unrotated input frame, i.e. *not* taking into account the {@code + * Orientation} flag of the input FrameBuffer, <br> + * \* their dimensions are intrinsic to the model, i.e. *not* dependent on the input FrameBuffer + * dimensions. + * + * <p>Example of such post-processing, assuming: <br> + * \* an input FrameBuffer with width=640, height=480, orientation=kLeftBottom (i.e. the image + * will be rotated 90° clockwise during preprocessing to make it "upright"), <br> + * \* a model outputting masks of size 224x224. <br> + * In order to be directly displayable on top of the input image assumed to be displayed *with* + * the {@code Orientation} flag taken into account (according to the <a + * href="http://jpegclub.org/exif_orientation.html">EXIF specification</a>), the masks need to be: + * re-scaled to 640 x 480, then rotated 90° clockwise. 
+ * + * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the + * {@code outputType} + */ + static Segmentation create( + OutputType outputType, List<TensorImage> masks, List<ColoredLabel> coloredLabels) { + outputType.assertMasksMatchColoredLabels(masks, coloredLabels); + + return new AutoValue_Segmentation( + outputType, + Collections.unmodifiableList(new ArrayList<TensorImage>(masks)), + Collections.unmodifiableList(new ArrayList<ColoredLabel>(coloredLabels))); + } + + public abstract OutputType getOutputType(); + + // As an open source project, we've been trying avoiding depending on common java libraries, + // such as Guava, because it may introduce conflicts with clients who also happen to use those + // libraries. Therefore, instead of using ImmutableList here, we convert the List into + // unmodifiableList in create() to make it less vulnerable. + public abstract List<TensorImage> getMasks(); + + public abstract List<ColoredLabel> getColoredLabels(); +} diff --git a/tensorflow_lite_support/java/src/native/task/core/BUILD b/tensorflow_lite_support/java/src/native/task/core/BUILD new file mode 100644 index 00000000..d4dd7ab3 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/core/BUILD @@ -0,0 +1,16 @@ +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +# Default provider for BuiltInOpResover. Create your own target, overwrite the +# function to provide a MutableOpResolver for customized OPs and/or a subset of +# builtin OPs. 
+cc_library( + name = "builtin_op_resolver", + srcs = ["builtin_op_resolver.cc"], + deps = [ + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) diff --git a/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc new file mode 100644 index 00000000..050f49fc --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/register.h" + +namespace tflite { +namespace task { +// Default provider for OpResolver, provides BuiltinOpResolver. 
+std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT + return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>( + new tflite::ops::builtin::BuiltinOpResolver); +} + +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/java/src/native/task/text/BUILD b/tensorflow_lite_support/java/src/native/task/text/BUILD new file mode 100644 index 00000000..a27aba52 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/text/BUILD @@ -0,0 +1,34 @@ +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "task_text_native", + srcs = [ + ":libtask_text_jni.so", + ], +) + +tflite_jni_binary( + name = "libtask_text_jni.so", + srcs = [ + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier:bert_nl_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/text/qa:bert_question_answerer_jni.cc", + ], + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + deps = [ + "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier", + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", + "//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni_utils", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], +) diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD new file mode 100644 index 
00000000..88f3e9be --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/BUILD @@ -0,0 +1,63 @@ +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "nl_classifier_jni.cc", +]) + +# Default native target for nl_classifier to provide BuiltInOpResolver. +cc_library( + name = "nl_classifier_native", + srcs = [ + ":libtask_text_jni.so", + ], +) + +# Note: "libtask_text_jni" is hardcoded in Java to look up the .so, therefore +# the name should remain the same when creating customized version of +# nl_classifier_native +tflite_jni_binary( + name = "libtask_text_jni.so", + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + deps = [ + ":native_without_resolver", + "//tensorflow_lite_support/java/src/native/task/core:builtin_op_resolver", + ], +) + +# Shared native logic for NLClassifier. Combine this target and customized +# version of op_resolver to build customized nl_classifier_native target. 
+cc_library( + name = "native_without_resolver", + srcs = [ + "nl_classifier_jni.cc", + ], + deps = [ + ":nl_classifier_jni_utils", + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], + alwayslink = 1, +) + +cc_library( + name = "nl_classifier_jni_utils", + srcs = [ + "nl_classifier_jni_utils.cc", + ], + hdrs = [ + "nl_classifier_jni_utils.h", + ], + deps = [ + "//tensorflow_lite_support/cc/task/text/nlclassifier:nl_classifier", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + ], +) diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD new file mode 100644 index 00000000..49f3f4e4 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/BUILD @@ -0,0 +1,31 @@ +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "bert_nl_classifier_jni.cc", +]) + +cc_library( + name = "bert_nl_classifier_native", + srcs = [ + ":libtask_text_jni.so", + ], +) + +tflite_jni_binary( + name = "libtask_text_jni.so", + srcs = [ + "bert_nl_classifier_jni.cc", + ], + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + deps = [ + "//tensorflow_lite_support/cc/task/text/nlclassifier:bert_nl_classifier", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + "//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni_utils", + ], +) diff --git 
a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc new file mode 100644 index 00000000..1edb3507 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <jni.h> + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" +#include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h" + +namespace { + +using ::tflite::support::utils::GetMappedFileBuffer; +using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::text::nlclassifier::BertNLClassifier; +using ::tflite::task::text::nlclassifier::RunClassifier; + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni( + JNIEnv* env, jobject thiz, jlong native_handle) { + delete reinterpret_cast<BertNLClassifier*>(native_handle); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer( + JNIEnv* env, jclass thiz, jobject model_buffer) { + auto model = GetMappedFileBuffer(env, model_buffer); + tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status = + BertNLClassifier::CreateFromBuffer(model.data(), model.size()); + if (status.ok()) { + return reinterpret_cast<jlong>(status->release()); + } else { + ThrowException(env, kAssertionError, + "Error occurred when initializing Bert NLClassifier: %s", + status.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor( + JNIEnv* env, jclass thiz, jint fd) { + tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status = + BertNLClassifier::CreateFromFd(fd); + if (status.ok()) { + return reinterpret_cast<jlong>(status->release()); + } else { + ThrowException(env, kAssertionError, + "Error occurred when initializing Bert NLClassifier: 
%s", + status.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative( + JNIEnv* env, jclass clazz, jlong native_handle, jstring text) { + return RunClassifier(env, native_handle, text); +} + +} // namespace diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc new file mode 100644 index 00000000..d2ace753 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc @@ -0,0 +1,135 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <jni.h> + +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/op_resolver.h" +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" +#include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h" + +namespace tflite { +namespace task { +// To be provided by a link-time library +extern std::unique_ptr<OpResolver> CreateOpResolver(); + +} // namespace task +} // namespace tflite + +namespace { + +using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::GetMappedFileBuffer; +using ::tflite::support::utils::JStringToString; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::text::nlclassifier::NLClassifier; +using ::tflite::task::text::nlclassifier::NLClassifierOptions; +using ::tflite::task::text::nlclassifier::RunClassifier; + + +NLClassifierOptions ConvertJavaNLClassifierOptions( + JNIEnv* env, jobject java_nl_classifier_options) { + jclass nl_classifier_options_class = env->FindClass( + "org/tensorflow/lite/task/text/nlclassifier/" + "NLClassifier$NLClassifierOptions"); + jmethodID input_tensor_index_method_id = + env->GetMethodID(nl_classifier_options_class, "inputTensorIndex", "()I"); + jmethodID output_score_tensor_index_method_id = env->GetMethodID( + nl_classifier_options_class, "outputScoreTensorIndex", "()I"); + jmethodID output_label_tensor_index_method_id = env->GetMethodID( + nl_classifier_options_class, "outputLabelTensorIndex", "()I"); + jmethodID input_tensor_name_method_id = env->GetMethodID( + nl_classifier_options_class, "inputTensorName", "()Ljava/lang/String;"); + jmethodID output_score_tensor_name_method_id = + env->GetMethodID(nl_classifier_options_class, "outputScoreTensorName", + "()Ljava/lang/String;"); + jmethodID 
output_label_tensor_name_method_id = + env->GetMethodID(nl_classifier_options_class, "outputLabelTensorName", + "()Ljava/lang/String;"); + + return { + .input_tensor_index = env->CallIntMethod(java_nl_classifier_options, + input_tensor_index_method_id), + .output_score_tensor_index = env->CallIntMethod( + java_nl_classifier_options, output_score_tensor_index_method_id), + .output_label_tensor_index = env->CallIntMethod( + java_nl_classifier_options, output_label_tensor_index_method_id), + .input_tensor_name = JStringToString( + env, (jstring)env->CallObjectMethod(java_nl_classifier_options, + input_tensor_name_method_id)), + .output_score_tensor_name = JStringToString( + env, + (jstring)env->CallObjectMethod(java_nl_classifier_options, + output_score_tensor_name_method_id)), + .output_label_tensor_name = JStringToString( + env, + (jstring)env->CallObjectMethod(java_nl_classifier_options, + output_label_tensor_name_method_id)), + }; +} + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_deinitJni( + JNIEnv* env, jobject thiz, jlong native_handle) { + delete reinterpret_cast<NLClassifier*>(native_handle); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithByteBuffer( + JNIEnv* env, jclass thiz, jobject nl_classifier_options, + jobject model_buffer) { + auto model = GetMappedFileBuffer(env, model_buffer); + tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status = + NLClassifier::CreateFromBufferAndOptions( + model.data(), model.size(), + ConvertJavaNLClassifierOptions(env, nl_classifier_options), + tflite::task::CreateOpResolver()); + + if (status.ok()) { + return reinterpret_cast<jlong>(status->release()); + } else { + ThrowException(env, kAssertionError, + "Error occurred when initializing NLClassifier: %s", + status.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL 
+Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_initJniWithFileDescriptor( + JNIEnv* env, jclass thiz, jobject nl_classifier_options, jint fd) { + tflite::support::StatusOr<std::unique_ptr<NLClassifier>> status = + NLClassifier::CreateFromFdAndOptions( + fd, ConvertJavaNLClassifierOptions(env, nl_classifier_options), + tflite::task::CreateOpResolver()); + if (status.ok()) { + return reinterpret_cast<jlong>(status->release()); + } else { + ThrowException(env, kAssertionError, + "Error occurred when initializing NLClassifier: %s", + status.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_NLClassifier_classifyNative( + JNIEnv* env, jclass thiz, jlong native_handle, jstring text) { + return RunClassifier(env, native_handle, text); +} + +} // namespace diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc new file mode 100644 index 00000000..c358bee1 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <jni.h> + +#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { + +using ::tflite::support::utils::ConvertVectorToArrayList; +using ::tflite::support::utils::JStringToString; +using ::tflite::task::core::Category; +using ::tflite::task::text::nlclassifier::NLClassifier; + +jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) { + auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle); + + auto results = nl_classifier->Classify(JStringToString(env, text)); + jclass category_class = + env->FindClass("org/tensorflow/lite/support/label/Category"); + jmethodID category_init = + env->GetMethodID(category_class, "<init>", "(Ljava/lang/String;F)V"); + + return ConvertVectorToArrayList<Category>( + env, results, + [env, category_class, category_init](const Category& category) { + jstring class_name = env->NewStringUTF(category.class_name.data()); + // Convert double to float as Java interface exposes float as scores. + jobject jcategory = + env->NewObject(category_class, category_init, class_name, + static_cast<float>(category.score)); + env->DeleteLocalRef(class_name); + return jcategory; + }); +} + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h new file mode 100644 index 00000000..2c59ab50 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TAKS_TEXT_NLCLASSIFIER_NL_CLASSIFIER_JNI_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TAKS_TEXT_NLCLASSIFIER_NL_CLASSIFIER_JNI_UTILS_H_ + +#include <jni.h> + +namespace tflite { +namespace task { +namespace text { +namespace nlclassifier { + +jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text); + +} // namespace nlclassifier +} // namespace text +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TAKS_TEXT_NLCLASSIFIER_NL_CLASSIFIER_JNI_UTILS_H_ diff --git a/tensorflow_lite_support/java/src/native/task/text/qa/BUILD b/tensorflow_lite_support/java/src/native/task/text/qa/BUILD new file mode 100644 index 00000000..9753e329 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/text/qa/BUILD @@ -0,0 +1,30 @@ +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "bert_question_answerer_jni.cc", +]) + +tflite_jni_binary( + name = "libtask_text_jni.so", + srcs = [ + "bert_question_answerer_jni.cc", + ], + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + deps = [ + "//tensorflow_lite_support/cc/task/text/qa:bert_question_answerer", + 
"//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + ], +) + +cc_library( + name = "bert_question_answerer_native", + srcs = [ + ":libtask_text_jni.so", + ], +) diff --git a/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc b/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc new file mode 100644 index 00000000..92f93467 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <jni.h> + +#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace { + +using ::tflite::support::utils::ConvertVectorToArrayList; +using ::tflite::support::utils::GetMappedFileBuffer; +using ::tflite::support::utils::JStringToString; +using ::tflite::task::text::qa::BertQuestionAnswerer; +using ::tflite::task::text::qa::QaAnswer; +using ::tflite::task::text::qa::QuestionAnswerer; + +constexpr int kInvalidPointer = 0; + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni( + JNIEnv* env, jobject thiz, jlong native_handle) { + delete reinterpret_cast<QuestionAnswerer*>(native_handle); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithModelWithMetadataByteBuffers( + JNIEnv* env, jclass thiz, jobjectArray model_buffers) { + absl::string_view model_with_metadata = + GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); + + tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status = + BertQuestionAnswerer::CreateFromBuffer( + model_with_metadata.data(), model_with_metadata.size()); + if (status.ok()) { + return reinterpret_cast<jlong>(status->release()); + } else { + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor( + JNIEnv* env, jclass thiz, jint fd) { + tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status = + BertQuestionAnswerer::CreateFromFd(fd); + if (status.ok()) { + return reinterpret_cast<jlong>(status->release()); + } else { + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers( + JNIEnv* env, jclass thiz, jobjectArray 
model_buffers) { + absl::string_view model = + GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); + absl::string_view vocab = + GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1)); + + tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status = + BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer( + model.data(), model.size(), vocab.data(), vocab.size()); + if (status.ok()) { + return reinterpret_cast<jlong>(status->release()); + } else { + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers( + JNIEnv* env, jclass thiz, jobjectArray model_buffers) { + absl::string_view model = + GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0)); + absl::string_view sp_model = + GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1)); + + tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status = + BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer( + model.data(), model.size(), sp_model.data(), sp_model.size()); + if (status.ok()) { + return reinterpret_cast<jlong>(status->release()); + } else { + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative( + JNIEnv* env, jclass thiz, jlong native_handle, jstring context, + jstring question) { + auto* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle); + + std::vector<QaAnswer> results = question_answerer->Answer( + JStringToString(env, context), JStringToString(env, question)); + jclass qa_answer_class = + env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer"); + jmethodID qa_answer_ctor = + env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V"); + + return ConvertVectorToArrayList<QaAnswer>( + env, results, + [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) { + jstring text = 
env->NewStringUTF(ans.text.data()); + jobject qa_answer = + env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start, + ans.pos.end, ans.pos.logit); + env->DeleteLocalRef(text); + return qa_answer; + }); +} + +} // namespace diff --git a/tensorflow_lite_support/java/src/native/task/vision/BUILD b/tensorflow_lite_support/java/src/native/task/vision/BUILD new file mode 100644 index 00000000..451a50ca --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/vision/BUILD @@ -0,0 +1,59 @@ +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "jni_utils", + srcs = [ + "jni_utils.cc", + ], + hdrs = [ + "jni_utils.h", + ], + deps = [ + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "task_vision_native", + srcs = [ + ":libtask_vision_jni.so", + ], +) + +tflite_jni_binary( + name = "libtask_vision_jni.so", + srcs = [ + "//tensorflow_lite_support/java/src/native/task/vision/classifier:image_classifier_jni.cc", + "//tensorflow_lite_support/java/src/native/task/vision/detector:object_detector_jni.cc", + "//tensorflow_lite_support/java/src/native/task/vision/segmenter:image_segmenter_jni.cc", + ], + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision:image_classifier", + "//tensorflow_lite_support/cc/task/vision:image_segmenter", + "//tensorflow_lite_support/cc/task/vision:object_detector", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + 
"//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD b/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD new file mode 100644 index 00000000..8bddc2ac --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/vision/classifier/BUILD @@ -0,0 +1,35 @@ +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["image_classifier_jni.cc"]) + +cc_library( + name = "image_classifier_native", + srcs = [ + ":libtask_vision_jni.so", + ], +) + +tflite_jni_binary( + name = "libtask_vision_jni.so", + srcs = [ + "image_classifier_jni.cc", + ], + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision:image_classifier", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", + 
"//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", + ], +) diff --git a/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc new file mode 100644 index 00000000..2d52f937 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc @@ -0,0 +1,234 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <jni.h> + +#include <memory> +#include <string> + +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/image_classifier.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" +#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" + +namespace { + +using ::tflite::support::StatusOr; +using ::tflite::support::utils::GetMappedFileBuffer; +using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::StringListToVector; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::vision::BoundingBox; +using ::tflite::task::vision::ClassificationResult; +using ::tflite::task::vision::Classifications; +using ::tflite::task::vision::ConvertToCategory; +using ::tflite::task::vision::ConvertToFrameBufferOrientation; +using ::tflite::task::vision::FrameBuffer; +using ::tflite::task::vision::ImageClassifier; +using ::tflite::task::vision::ImageClassifierOptions; + +// Creates an ImageClassifierOptions proto based on the Java class. 
+ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env, + jobject java_options) { + ImageClassifierOptions proto_options; + jclass java_options_class = env->FindClass( + "org/tensorflow/lite/task/vision/classifier/" + "ImageClassifier$ImageClassifierOptions"); + + jmethodID display_names_locale_id = env->GetMethodID( + java_options_class, "getDisplayNamesLocale", "()Ljava/lang/String;"); + jstring display_names_locale = static_cast<jstring>( + env->CallObjectMethod(java_options, display_names_locale_id)); + const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr); + proto_options.set_display_names_locale(pchars); + env->ReleaseStringUTFChars(display_names_locale, pchars); + + jmethodID max_results_id = + env->GetMethodID(java_options_class, "getMaxResults", "()I"); + jint max_results = env->CallIntMethod(java_options, max_results_id); + proto_options.set_max_results(max_results); + + jmethodID is_score_threshold_set_id = + env->GetMethodID(java_options_class, "getIsScoreThresholdSet", "()Z"); + jboolean is_score_threshold_set = + env->CallBooleanMethod(java_options, is_score_threshold_set_id); + if (is_score_threshold_set) { + jmethodID score_threshold_id = + env->GetMethodID(java_options_class, "getScoreThreshold", "()F"); + jfloat score_threshold = + env->CallFloatMethod(java_options, score_threshold_id); + proto_options.set_score_threshold(score_threshold); + } + + jmethodID allow_list_id = env->GetMethodID( + java_options_class, "getLabelAllowList", "()Ljava/util/List;"); + jobject allow_list = env->CallObjectMethod(java_options, allow_list_id); + auto allow_list_vector = StringListToVector(env, allow_list); + for (const auto& class_name : allow_list_vector) { + proto_options.add_class_name_whitelist(class_name); + } + + jmethodID deny_list_id = env->GetMethodID( + java_options_class, "getLabelDenyList", "()Ljava/util/List;"); + jobject deny_list = env->CallObjectMethod(java_options, deny_list_id); + auto deny_list_vector = 
StringListToVector(env, deny_list); + for (const auto& class_name : deny_list_vector) { + proto_options.add_class_name_blacklist(class_name); + } + + jmethodID num_threads_id = + env->GetMethodID(java_options_class, "getNumThreads", "()I"); + jint num_threads = env->CallIntMethod(java_options, num_threads_id); + proto_options.set_num_threads(num_threads); + + return proto_options; +} + +jobject ConvertToClassificationResults(JNIEnv* env, + const ClassificationResult& results) { + // jclass and init of Classifications. + jclass classifications_class = env->FindClass( + "org/tensorflow/lite/task/vision/classifier/Classifications"); + jmethodID classifications_create = + env->GetStaticMethodID(classifications_class, "create", + "(Ljava/util/List;I)Lorg/tensorflow/lite/" + "task/vision/classifier/Classifications;"); + + // jclass, init, and add of ArrayList. + jclass array_list_class = env->FindClass("java/util/ArrayList"); + jmethodID array_list_init = + env->GetMethodID(array_list_class, "<init>", "(I)V"); + jmethodID array_list_add_method = + env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); + + jobject classifications_list = + env->NewObject(array_list_class, array_list_init, + static_cast<jint>(results.classifications_size())); + for (int i = 0; i < results.classifications_size(); i++) { + auto classifications = results.classifications(i); + jobject jcategory_list = env->NewObject(array_list_class, array_list_init, + classifications.classes_size()); + for (const auto& classification : classifications.classes()) { + jobject jcategory = ConvertToCategory(env, classification); + env->CallBooleanMethod(jcategory_list, array_list_add_method, jcategory); + + env->DeleteLocalRef(jcategory); + } + jobject jclassifications = env->CallStaticObjectMethod( + classifications_class, classifications_create, jcategory_list, + classifications.head_index()); + env->CallBooleanMethod(classifications_list, array_list_add_method, + jclassifications); + + 
env->DeleteLocalRef(jcategory_list); + env->DeleteLocalRef(jclassifications); + } + return classifications_list; +} + +jlong CreateImageClassifierFromOptions(JNIEnv* env, + const ImageClassifierOptions& options) { + StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or = + ImageClassifier::CreateFromOptions(options); + if (image_classifier_or.ok()) { + // Deletion is handled at deinitJni time. + return reinterpret_cast<jlong>(image_classifier_or->release()); + } else { + ThrowException(env, kAssertionError, + "Error occurred when initializing ImageClassifier: %s", + image_classifier_or.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_deinitJni( + JNIEnv* env, jobject thiz, jlong native_handle) { + delete reinterpret_cast<ImageClassifier*>(native_handle); +} + +// Creates an ImageClassifier instance from the model file descriptor. +// file_descriptor_length and file_descriptor_offset are optional. Non-possitive +// values will be ignored. 
+extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithModelFdAndOptions( + JNIEnv* env, jclass thiz, jint file_descriptor, + jlong file_descriptor_length, jlong file_descriptor_offset, + jobject java_options) { + ImageClassifierOptions proto_options = + ConvertToProtoOptions(env, java_options); + auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata() + ->mutable_file_descriptor_meta(); + file_descriptor_meta->set_fd(file_descriptor); + if (file_descriptor_length > 0) { + file_descriptor_meta->set_length(file_descriptor_length); + } + if (file_descriptor_offset > 0) { + file_descriptor_meta->set_offset(file_descriptor_offset); + } + return CreateImageClassifierFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_initJniWithByteBuffer( + JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) { + ImageClassifierOptions proto_options = + ConvertToProtoOptions(env, java_options); + // External proto generated header does not overload `set_file_content` with + // string_view, therefore GetMappedFileBuffer does not apply here. + // Creating a std::string will cause one extra copying of data. Thus, the + // most efficient way here is to set file_content using char* and its size. 
+ proto_options.mutable_model_file_with_metadata()->set_file_content( + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); + return CreateImageClassifierFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_tensorflow_lite_task_vision_classifier_ImageClassifier_classifyNative( + JNIEnv* env, jclass thiz, jlong native_handle, jobject image_byte_buffer, + jint width, jint height, jintArray jroi, jint jorientation) { + auto* classifier = reinterpret_cast<ImageClassifier*>(native_handle); + auto image = GetMappedFileBuffer(env, image_byte_buffer); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + reinterpret_cast<const uint8*>(image.data()), + FrameBuffer::Dimension{width, height}, + ConvertToFrameBufferOrientation(env, jorientation)); + + int* roi_array = env->GetIntArrayElements(jroi, 0); + BoundingBox roi; + roi.set_origin_x(roi_array[0]); + roi.set_origin_y(roi_array[1]); + roi.set_width(roi_array[2]); + roi.set_height(roi_array[3]); + env->ReleaseIntArrayElements(jroi, roi_array, 0); + + auto results_or = classifier->Classify(*frame_buffer, roi); + if (results_or.ok()) { + return ConvertToClassificationResults(env, results_or.value()); + } else { + ThrowException(env, kAssertionError, + "Error occurred when classifying the image: %s", + results_or.status().message().data()); + return nullptr; + } +} +} // namespace diff --git a/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD b/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD new file mode 100644 index 00000000..5abd3f1a --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/vision/detector/BUILD @@ -0,0 +1,36 @@ +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + 
+exports_files(["object_detector_jni.cc"]) + +cc_library( + name = "object_detector_native", + srcs = [ + ":libtask_vision_jni.so", + ], +) + +tflite_jni_binary( + name = "libtask_vision_jni.so", + srcs = [ + "object_detector_jni.cc", + ], + linkscript = "//tensorflow_lite_support/java:default_version_script.lds", + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision:object_detector", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc b/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc new file mode 100644 index 00000000..016b0bfd --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc @@ -0,0 +1,228 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <jni.h> + +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/object_detector.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" +#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" + +namespace { + +using ::tflite::support::StatusOr; +using ::tflite::support::utils::GetMappedFileBuffer; +using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::StringListToVector; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::vision::BoundingBox; +using ::tflite::task::vision::ConvertToCategory; +using ::tflite::task::vision::ConvertToFrameBufferOrientation; +using ::tflite::task::vision::DetectionResult; +using ::tflite::task::vision::FrameBuffer; +using ::tflite::task::vision::ObjectDetector; +using ::tflite::task::vision::ObjectDetectorOptions; + +// Creates an ObjectDetectorOptions proto based on the Java class. 
+ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options) { + ObjectDetectorOptions proto_options; + jclass java_options_class = env->FindClass( + "org/tensorflow/lite/task/vision/detector/" + "ObjectDetector$ObjectDetectorOptions"); + + jmethodID display_names_locale_id = env->GetMethodID( + java_options_class, "getDisplayNamesLocale", "()Ljava/lang/String;"); + jstring display_names_locale = static_cast<jstring>( + env->CallObjectMethod(java_options, display_names_locale_id)); + const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr); + proto_options.set_display_names_locale(pchars); + env->ReleaseStringUTFChars(display_names_locale, pchars); + + jmethodID max_results_id = + env->GetMethodID(java_options_class, "getMaxResults", "()I"); + jint max_results = env->CallIntMethod(java_options, max_results_id); + proto_options.set_max_results(max_results); + + jmethodID is_score_threshold_set_id = + env->GetMethodID(java_options_class, "getIsScoreThresholdSet", "()Z"); + jboolean is_score_threshold_set = + env->CallBooleanMethod(java_options, is_score_threshold_set_id); + if (is_score_threshold_set) { + jmethodID score_threshold_id = + env->GetMethodID(java_options_class, "getScoreThreshold", "()F"); + jfloat score_threshold = + env->CallFloatMethod(java_options, score_threshold_id); + proto_options.set_score_threshold(score_threshold); + } + + jmethodID allow_list_id = env->GetMethodID( + java_options_class, "getLabelAllowList", "()Ljava/util/List;"); + jobject allow_list = env->CallObjectMethod(java_options, allow_list_id); + std::vector<std::string> allow_list_vector = + StringListToVector(env, allow_list); + for (const auto& class_name : allow_list_vector) { + proto_options.add_class_name_whitelist(class_name); + } + + jmethodID deny_list_id = env->GetMethodID( + java_options_class, "getLabelDenyList", "()Ljava/util/List;"); + jobject deny_list = env->CallObjectMethod(java_options, deny_list_id); + auto deny_list_vector = 
StringListToVector(env, deny_list); + for (const auto& class_name : deny_list_vector) { + proto_options.add_class_name_blacklist(class_name); + } + + jmethodID num_threads_id = + env->GetMethodID(java_options_class, "getNumThreads", "()I"); + jint num_threads = env->CallIntMethod(java_options, num_threads_id); + proto_options.set_num_threads(num_threads); + + return proto_options; +} + +jobject ConvertToDetectionResults(JNIEnv* env, const DetectionResult& results) { + // jclass and init of Detection. + jclass detection_class = + env->FindClass("org/tensorflow/lite/task/vision/detector/Detection"); + jmethodID detection_create = env->GetStaticMethodID( + detection_class, "create", + "(Landroid/graphics/RectF;Ljava/util/List;)Lorg/tensorflow/lite/" + "task/vision/detector/Detection;"); + + // jclass, init, and add of ArrayList. + jclass array_list_class = env->FindClass("java/util/ArrayList"); + jmethodID array_list_init = + env->GetMethodID(array_list_class, "<init>", "(I)V"); + jmethodID array_list_add_method = + env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); + + // jclass, init of RectF. + jclass rectf_class = env->FindClass("android/graphics/RectF"); + jmethodID rectf_init = env->GetMethodID(rectf_class, "<init>", "(FFFF)V"); + + jobject detections_list = + env->NewObject(array_list_class, array_list_init, + static_cast<jint>(results.detections_size())); + + for (const auto& detection : results.detections()) { + // Create the category list. + jobject category_list = env->NewObject(array_list_class, array_list_init, + detection.classes_size()); + for (const auto& classification : detection.classes()) { + jobject jcategory = ConvertToCategory(env, classification); + env->CallBooleanMethod(category_list, array_list_add_method, jcategory); + } + + // Create the bounding box object. 
+ const BoundingBox& bounding_box = detection.bounding_box(); + float left = static_cast<float>(bounding_box.origin_x()); + float top = static_cast<float>(bounding_box.origin_y()); + float right = static_cast<float>(left + bounding_box.width()); + float bottom = static_cast<float>(top + bounding_box.height()); + jobject jbounding_box = + env->NewObject(rectf_class, rectf_init, left, top, right, bottom); + + // Create the java Detection object. + jobject jdetection = env->CallStaticObjectMethod( + detection_class, detection_create, jbounding_box, category_list); + env->CallBooleanMethod(detections_list, array_list_add_method, jdetection); + } + return detections_list; +} + +jlong CreateObjectDetectorFromOptions(JNIEnv* env, + const ObjectDetectorOptions& options) { + StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or = + ObjectDetector::CreateFromOptions(options); + if (object_detector_or.ok()) { + return reinterpret_cast<jlong>(object_detector_or->release()); + } else { + ThrowException(env, kAssertionError, + "Error occurred when initializing ObjectDetector: %s", + object_detector_or.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_deinitJni( + JNIEnv* env, jobject thiz, jlong native_handle) { + delete reinterpret_cast<ObjectDetector*>(native_handle); +} + +// Creates an ObjectDetector instance from the model file descriptor. +// file_descriptor_length and file_descriptor_offset are optional. Non-possitive +// values will be ignored. 
+extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdAndOptions( + JNIEnv* env, jclass thiz, jint file_descriptor, + jlong file_descriptor_length, jlong file_descriptor_offset, + jobject java_options) { + ObjectDetectorOptions proto_options = + ConvertToProtoOptions(env, java_options); + auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata() + ->mutable_file_descriptor_meta(); + file_descriptor_meta->set_fd(file_descriptor); + if (file_descriptor_length > 0) { + file_descriptor_meta->set_length(file_descriptor_length); + } + if (file_descriptor_offset > 0) { + file_descriptor_meta->set_offset(file_descriptor_offset); + } + return CreateObjectDetectorFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuffer( + JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) { + ObjectDetectorOptions proto_options = + ConvertToProtoOptions(env, java_options); + proto_options.mutable_model_file_with_metadata()->set_file_content( + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); + return CreateObjectDetectorFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jobject JNICALL +Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_detectNative( + JNIEnv* env, jclass thiz, jlong native_handle, jobject image_byte_buffer, + jint width, jint height, jint jorientation) { + auto* detector = reinterpret_cast<ObjectDetector*>(native_handle); + absl::string_view image = GetMappedFileBuffer(env, image_byte_buffer); + std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + reinterpret_cast<const uint8*>(image.data()), + FrameBuffer::Dimension{width, height}, + ConvertToFrameBufferOrientation(env, jorientation)); + auto results_or = detector->Detect(*frame_buffer); + if (results_or.ok()) 
{ + return ConvertToDetectionResults(env, results_or.value()); + } else { + ThrowException(env, kAssertionError, + "Error occurred when detecting the image: %s", + results_or.status().message().data()); + return nullptr; + } +} +} // namespace diff --git a/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc b/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc new file mode 100644 index 00000000..af5dad96 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/vision/jni_utils.cc @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" + +namespace tflite { +namespace task { +namespace vision { + +using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::ThrowException; + +constexpr char kCategoryClassName[] = + "org/tensorflow/lite/support/label/Category"; +constexpr char kStringClassName[] = "Ljava/lang/String;"; +constexpr char kEmptyString[] = ""; + +jobject ConvertToCategory(JNIEnv* env, const Class& classification) { + // jclass and init of Category. 
+ jclass category_class = env->FindClass(kCategoryClassName); + jmethodID category_create = env->GetStaticMethodID( + category_class, "create", + absl::StrCat("(", kStringClassName, kStringClassName, "F)L", + kCategoryClassName, ";") + .c_str()); + + std::string label_string = classification.has_class_name() + ? classification.class_name() + : std::to_string(classification.index()); + jstring label = env->NewStringUTF(label_string.c_str()); + std::string display_name_string = classification.has_display_name() + ? classification.display_name() + : kEmptyString; + jstring display_name = env->NewStringUTF(display_name_string.c_str()); + jobject jcategory = + env->CallStaticObjectMethod(category_class, category_create, label, + display_name, classification.score()); + return jcategory; +} + +FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env, + jint jorientation) { + switch (jorientation) { + case 0: + return FrameBuffer::Orientation::kTopLeft; + case 1: + return FrameBuffer::Orientation::kTopRight; + case 2: + return FrameBuffer::Orientation::kBottomRight; + case 3: + return FrameBuffer::Orientation::kBottomLeft; + case 4: + return FrameBuffer::Orientation::kLeftTop; + case 5: + return FrameBuffer::Orientation::kRightTop; + case 6: + return FrameBuffer::Orientation::kRightBottom; + case 7: + return FrameBuffer::Orientation::kLeftBottom; + } + // Should never happen. + ThrowException(env, kAssertionError, + "The FrameBuffer Orientation type is unsupported: %d", + jorientation); + return FrameBuffer::Orientation::kTopLeft; +} + +} // namespace vision +} // namespace task +} // namespace tflite diff --git a/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h b/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h new file mode 100644 index 00000000..7cb63f31 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/vision/jni_utils.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TASK_VISION_JNI_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TASK_VISION_JNI_UTILS_H_ + +#include <jni.h> + +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" + +namespace tflite { +namespace task { +namespace vision { + +// Creates a Java Category object based on Class. 
+jobject ConvertToCategory(JNIEnv* env, const Class& classification); + +FrameBuffer::Orientation ConvertToFrameBufferOrientation(JNIEnv* env, + jint jorientation); + +} // namespace vision +} // namespace task +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_JAVA_SRC_NATIVE_TASK_VISION_JNI_UTILS_H_ diff --git a/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD b/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD new file mode 100644 index 00000000..23d601df --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD @@ -0,0 +1,34 @@ +load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["image_segmenter_jni.cc"]) + +cc_library( + name = "image_segmenter_native", + srcs = [ + ":libtask_vision_jni.so", + ], +) + +tflite_jni_binary( + name = "libtask_vision_jni.so", + srcs = [ + "image_segmenter_jni.cc", + ], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision:image_segmenter", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", + "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/cc/utils:jni_utils", + "//tensorflow_lite_support/java/jni", + "//tensorflow_lite_support/java/src/native/task/vision:jni_utils", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc b/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc new file mode 100644 index 00000000..1f6a2dc3 --- /dev/null +++ 
b/tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc @@ -0,0 +1,238 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <jni.h> + +#include <memory> +#include <string> + +#include "absl/strings/str_cat.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" +#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/cc/utils/jni_utils.h" +#include "tensorflow_lite_support/java/src/native/task/vision/jni_utils.h" + +namespace { + +using ::tflite::support::StatusOr; +using ::tflite::support::utils::CreateByteArray; +using ::tflite::support::utils::GetMappedFileBuffer; +using ::tflite::support::utils::kAssertionError; +using ::tflite::support::utils::kIllegalArgumentException; +using ::tflite::support::utils::kInvalidPointer; +using ::tflite::support::utils::ThrowException; +using ::tflite::task::vision::ConvertToFrameBufferOrientation; +using ::tflite::task::vision::FrameBuffer; +using ::tflite::task::vision::ImageSegmenter; +using 
::tflite::task::vision::ImageSegmenterOptions; +using ::tflite::task::vision::Segmentation; +using ::tflite::task::vision::SegmentationResult; + +constexpr char kArrayListClassNameNoSig[] = "java/util/ArrayList"; +constexpr char kObjectClassName[] = "Ljava/lang/Object;"; +constexpr char kColorClassName[] = "Landroid/graphics/Color;"; +constexpr char kColorClassNameNoSig[] = "android/graphics/Color"; +constexpr char kColoredLabelClassName[] = + "Lorg/tensorflow/lite/task/vision/segmenter/ColoredLabel;"; +constexpr char kColoredLabelClassNameNoSig[] = + "org/tensorflow/lite/task/vision/segmenter/ColoredLabel"; +constexpr char kStringClassName[] = "Ljava/lang/String;"; +constexpr int kOutputTypeCategoryMask = 0; +constexpr int kOutputTypeConfidenceMask = 1; + +// Creates an ImageSegmenterOptions proto based on the Java class. +ImageSegmenterOptions ConvertToProtoOptions(JNIEnv* env, + jstring display_names_locale, + jint output_type, + jint num_threads) { + ImageSegmenterOptions proto_options; + + const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr); + proto_options.set_display_names_locale(pchars); + env->ReleaseStringUTFChars(display_names_locale, pchars); + + switch (output_type) { + case kOutputTypeCategoryMask: + proto_options.set_output_type(ImageSegmenterOptions::CATEGORY_MASK); + break; + case kOutputTypeConfidenceMask: + proto_options.set_output_type(ImageSegmenterOptions::CONFIDENCE_MASK); + break; + default: + // Should never happen. + ThrowException(env, kIllegalArgumentException, + "Unsupported output type: %d", output_type); + } + + proto_options.set_num_threads(num_threads); + + return proto_options; +} + +void ConvertToSegmentationResults(JNIEnv* env, + const SegmentationResult& results, + jobject jmask_buffers, jintArray jmask_shape, + jobject jcolored_labels) { + if (results.segmentation_size() != 1) { + // Should never happen. 
+ ThrowException( + env, kAssertionError, + "ImageSegmenter only supports one segmentation result, getting %d", + results.segmentation_size()); + } + + const Segmentation& segmentation = results.segmentation(0); + + // Get the shape from the C++ Segmentation results. + int shape_array[2] = {segmentation.height(), segmentation.width()}; + env->SetIntArrayRegion(jmask_shape, 0, 2, shape_array); + + // jclass, init, and add of ArrayList. + jclass array_list_class = env->FindClass(kArrayListClassNameNoSig); + jmethodID array_list_add_method = + env->GetMethodID(array_list_class, "add", + absl::StrCat("(", kObjectClassName, ")Z").c_str()); + + // Convert the masks into ByteBuffer list. + int num_pixels = segmentation.height() * segmentation.width(); + if (segmentation.has_category_mask()) { + jbyteArray byte_array = CreateByteArray( + env, + reinterpret_cast<const jbyte*>(segmentation.category_mask().data()), + num_pixels * sizeof(uint8)); + env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array); + env->DeleteLocalRef(byte_array); + } else { + for (const auto& confidence_mask : + segmentation.confidence_masks().confidence_mask()) { + jbyteArray byte_array = CreateByteArray( + env, reinterpret_cast<const jbyte*>(confidence_mask.value().data()), + num_pixels * sizeof(float)); + env->CallBooleanMethod(jmask_buffers, array_list_add_method, byte_array); + env->DeleteLocalRef(byte_array); + } + } + + // Convert colored labels from the C++ object to the Java object. 
+ jclass color_class = env->FindClass(kColorClassNameNoSig); + jmethodID color_rgb_method = + env->GetStaticMethodID(color_class, "rgb", "(III)I"); + jclass colored_label_class = env->FindClass(kColoredLabelClassNameNoSig); + jmethodID colored_label_create_method = env->GetStaticMethodID( + colored_label_class, "create", + absl::StrCat("(", kStringClassName, kStringClassName, "I)", + kColoredLabelClassName) + .c_str()); + + for (const auto& colored_label : segmentation.colored_labels()) { + jstring label = env->NewStringUTF(colored_label.class_name().c_str()); + jstring display_name = + env->NewStringUTF(colored_label.display_name().c_str()); + jint rgb = env->CallStaticIntMethod(color_class, color_rgb_method, + colored_label.r(), colored_label.g(), + colored_label.b()); + jobject jcolored_label = env->CallStaticObjectMethod( + colored_label_class, colored_label_create_method, label, display_name, + rgb); + env->CallBooleanMethod(jcolored_labels, array_list_add_method, + jcolored_label); + + env->DeleteLocalRef(label); + env->DeleteLocalRef(display_name); + env->DeleteLocalRef(jcolored_label); + } +} + +jlong CreateImageClassifierFromOptions(JNIEnv* env, + const ImageSegmenterOptions& options) { + StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or = + ImageSegmenter::CreateFromOptions(options); + if (image_segmenter_or.ok()) { + return reinterpret_cast<jlong>(image_segmenter_or->release()); + } else { + ThrowException(env, kAssertionError, + "Error occurred when initializing ImageSegmenter: %s", + image_segmenter_or.status().message().data()); + return kInvalidPointer; + } +} + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_deinitJni( + JNIEnv* env, jobject thiz, jlong native_handle) { + delete reinterpret_cast<ImageSegmenter*>(native_handle); +} + +// Creates an ImageSegmenter instance from the model file descriptor. +// file_descriptor_length and file_descriptor_offset are optional. 
Non-possitive +// values will be ignored. +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions( + JNIEnv* env, jclass thiz, jint file_descriptor, + jlong file_descriptor_length, jlong file_descriptor_offset, + jstring display_names_locale, jint output_type, jint num_threads) { + ImageSegmenterOptions proto_options = ConvertToProtoOptions( + env, display_names_locale, output_type, num_threads); + auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata() + ->mutable_file_descriptor_meta(); + file_descriptor_meta->set_fd(file_descriptor); + if (file_descriptor_length > 0) { + file_descriptor_meta->set_length(file_descriptor_length); + } + if (file_descriptor_offset > 0) { + file_descriptor_meta->set_offset(file_descriptor_offset); + } + return CreateImageClassifierFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer( + JNIEnv* env, jclass thiz, jobject model_buffer, + jstring display_names_locale, jint output_type, jint num_threads) { + ImageSegmenterOptions proto_options = ConvertToProtoOptions( + env, display_names_locale, output_type, num_threads); + proto_options.mutable_model_file_with_metadata()->set_file_content( + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)), + static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer))); + return CreateImageClassifierFromOptions(env, proto_options); +} + +extern "C" JNIEXPORT void JNICALL +Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_segmentNative( + JNIEnv* env, jclass thiz, jlong native_handle, jobject jimage_byte_buffer, + jint width, jint height, jobject jmask_buffers, jintArray jmask_shape, + jobject jcolored_labels, jint jorientation) { + auto* segmenter = reinterpret_cast<ImageSegmenter*>(native_handle); + absl::string_view image = GetMappedFileBuffer(env, jimage_byte_buffer); + 
std::unique_ptr<FrameBuffer> frame_buffer = CreateFromRgbRawBuffer( + reinterpret_cast<const uint8*>(image.data()), + FrameBuffer::Dimension{width, height}, + ConvertToFrameBufferOrientation(env, jorientation)); + auto results_or = segmenter->Segment(*frame_buffer); + if (results_or.ok()) { + ConvertToSegmentationResults(env, results_or.value(), jmask_buffers, + jmask_shape, jcolored_labels); + } else { + ThrowException(env, kAssertionError, + "Error occurred when segmenting the image: %s", + results_or.status().message().data()); + } +} + +} // namespace diff --git a/tensorflow_lite_support/metadata/BUILD b/tensorflow_lite_support/metadata/BUILD new file mode 100644 index 00000000..db69bd25 --- /dev/null +++ b/tensorflow_lite_support/metadata/BUILD @@ -0,0 +1,51 @@ +load("@flatbuffers//:build_defs.bzl", "flatbuffer_android_library", "flatbuffer_cc_library", "flatbuffer_java_library", "flatbuffer_py_library") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["metadata_schema.fbs"]) + +flatbuffer_py_library( + name = "schema_py", + srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"], +) + +# Generic schema for inference on device. +flatbuffer_android_library( + name = "schema_fbs_android", + srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"], + custom_package = "org.tensorflow.lite.schema", +) + +flatbuffer_java_library( + name = "schema_fbs_java", + srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"], + custom_package = "org.tensorflow.lite.schema", +) + +# Generic schema for model metadata. 
+flatbuffer_cc_library( + name = "metadata_schema_cc", + srcs = ["metadata_schema.fbs"], +) + +flatbuffer_py_library( + name = "metadata_schema_py", + srcs = ["metadata_schema.fbs"], +) + +flatbuffer_java_library( + name = "metadata_schema_java", + srcs = ["metadata_schema.fbs"], + custom_package = "org.tensorflow.lite.support.metadata.schema", +) + +flatbuffer_android_library( + name = "metadata_schema_fbs_android", + srcs = ["metadata_schema.fbs"], + custom_package = "org.tensorflow.lite.support.metadata.schema", +) diff --git a/tensorflow_lite_support/metadata/README.md b/tensorflow_lite_support/metadata/README.md new file mode 100644 index 00000000..ff7d25f2 --- /dev/null +++ b/tensorflow_lite_support/metadata/README.md @@ -0,0 +1,15 @@ +# TensorFlow Lite Metadata and Android wrapper code generator + +Note: Both TensorFlow Lite Metadata and the Android wrapper code generator are +in experimental (beta) phase. + +TensorFlow Lite metadata provides a structured framework for storing metadata +to convey information for both the developer that will utilitised the model and +code generators which can create wrapper around the model. For information on +how to populate model metadata, please refer to the [TensorFlow Lite Metadata +documentation](https://www.tensorflow.org/lite/convert/metadata). + +The first code generator which takes advantage of this metadata format is the +TensorFlow Lite Android Code Generator. For more information on how to use this +generator, please refer to the [TensorFlow Lite Android wrapper code generator +documentation](https://www.tensorflow.org/lite/guide/codegen). 
diff --git a/tensorflow_lite_support/metadata/build_defs.bzl b/tensorflow_lite_support/metadata/build_defs.bzl new file mode 100644 index 00000000..8bdab125 --- /dev/null +++ b/tensorflow_lite_support/metadata/build_defs.bzl @@ -0,0 +1,43 @@ +"""Build rules to generate metadata schema versions.""" + +METADATA_SCHEMA_FILE = "//tensorflow_lite_support/metadata:metadata_schema.fbs" + +def stamp_metadata_parser_version( + name, + srcs, + outs): + """Stamps the latest metadata parser version into the srcs files. + + Replaces all the occurrences of "{LATEST_METADATA_PARSER_VERSION}" in the + srcs files with the metadata schema version extracted from + METADATA_SCHEMA_FILE and then outputs the generated file into outs, + respectively. The number of srcs files needs to match the number of outs + files. + + Args: + name: Rule name. (required) + srcs: List of source files. (required) + outs: List of output files. (required) + """ + if len(srcs) != len(outs): + fail(("The number of srcs files (%d) does not match that of the outs" + + " files (%d).") % + (len(srcs), len(outs))) + + for i in range(0, len(srcs)): + native.genrule( + name = "%s_file%d" % (name, i), + srcs = [srcs[i]], + outs = [outs[i]], + tools = [METADATA_SCHEMA_FILE], + # Gets the metadata schema version from the file, and stamps it + # into the srcs file. 
+ cmd = "version=$$(sed -n -e '/Schema Semantic version/ s/.*\\: *//p' $(location %s));" % + METADATA_SCHEMA_FILE + + 'sed "s/{LATEST_METADATA_PARSER_VERSION}/$$version/" $< > $@', + ) + + native.filegroup( + name = name, + srcs = outs, + ) diff --git a/tensorflow_lite_support/metadata/cc/BUILD b/tensorflow_lite_support/metadata/cc/BUILD new file mode 100644 index 00000000..ed5bedc0 --- /dev/null +++ b/tensorflow_lite_support/metadata/cc/BUILD @@ -0,0 +1,53 @@ +load("//tensorflow_lite_support/metadata:build_defs.bzl", "stamp_metadata_parser_version") + +package( + default_visibility = ["//tensorflow_lite_support:users"], + licenses = ["notice"], # Apache 2.0 +) + +stamp_metadata_parser_version( + name = "metadata_parser_h", + srcs = ["metadata_parser.h.template"], + outs = ["metadata_parser.h"], +) + +cc_library( + name = "metadata_extractor", + srcs = ["metadata_extractor.cc"], + hdrs = ["metadata_extractor.h"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@flatbuffers", + "@org_libzip//:zip", + ] + select({ + "//tensorflow_lite_support/cc:tflite_use_c_api": ["@org_tensorflow//tensorflow/lite/c:c_api"], + "//conditions:default": ["@org_tensorflow//tensorflow/lite:framework"], + }) + [ + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + "//tensorflow_lite_support/cc:common", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/metadata:metadata_schema_cc", + ], +) + +cc_library( + name = "metadata_version", + srcs = ["metadata_version.cc"], + hdrs = [ + "metadata_version.h", + ":metadata_parser_h", + ], + deps = [ + "//tensorflow_lite_support/metadata:metadata_schema_cc", + "@com_google_absl//absl/strings", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite/c:common", + 
"@org_tensorflow//tensorflow/lite/kernels/internal:compatibility", + "@org_tensorflow//tensorflow/lite/tools:logging", + ], +) diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc new file mode 100644 index 00000000..cf4edaa7 --- /dev/null +++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc @@ -0,0 +1,366 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" + +#include <functional> + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "lib/zip.h" // from @org_libzip +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow_lite_support/cc/common.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +#if TFLITE_USE_C_API +#include "tensorflow/lite/c/c_api.h" +#else +#include "tensorflow/lite/model_builder.h" +#endif + +namespace tflite { +namespace metadata { + +namespace { +constexpr char kMetadataBufferName[] = "TFLITE_METADATA"; + +using ::absl::StatusCode; +using ::flatbuffers::Offset; +using ::flatbuffers::Vector; +using ::tflite::TensorMetadata; +using 
::tflite::support::CreateStatusWithPayload; +using ::tflite::support::TfLiteSupportStatus; + +// Helper class that takes a callback function, and invokes it in its +// destructor. +class SimpleCleanUp { + public: + explicit SimpleCleanUp(std::function<void()> callback) + : callback_(std::move(callback)) {} + + ~SimpleCleanUp() { + if (callback_ != nullptr) callback_(); + } + + // Use `std::move(simple_cleanup).Cancel()` to prevent the callback from ever + // executing at all. Once a SimpleCleanUp object has been `std::move(...)`-ed, + // it may not be read from again. + void Cancel() && { callback_ = nullptr; } + + private: + std::function<void()> callback_; +}; + +// Util to get item from src_vector specified by index. +template <typename T> +const T* GetItemFromVector( + const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) { + if (src_vector == nullptr || index < 0 || index >= src_vector->size()) { + return nullptr; + } + return src_vector->Get(index); +} +} // namespace + +/* static */ +tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>> +ModelMetadataExtractor::CreateFromModelBuffer(const char* buffer_data, + size_t buffer_size) { + // Use absl::WrapUnique() to call private constructor: + // https://abseil.io/tips/126. 
+ std::unique_ptr<ModelMetadataExtractor> extractor = + absl::WrapUnique(new ModelMetadataExtractor()); + RETURN_IF_ERROR(extractor->InitFromModelBuffer(buffer_data, buffer_size)); + return extractor; +} + +/* static */ +tflite::support::StatusOr<const tflite::ProcessUnit*> +ModelMetadataExtractor::FindFirstProcessUnit( + const tflite::TensorMetadata& tensor_metadata, + tflite::ProcessUnitOptions type) { + const tflite::ProcessUnit* result = nullptr; + if (tensor_metadata.process_units() == nullptr) { + return result; + } + for (const auto process_unit : *tensor_metadata.process_units()) { + if (process_unit->options_type() == type) { + if (result != nullptr) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrCat("Found multiple ProcessUnits with type=", + tflite::EnumNameProcessUnitOptions(type), + ", expected at most one."), + TfLiteSupportStatus::kMetadataInvalidProcessUnitsError); + } + result = process_unit; + } + } + return result; +} + +/* static */ +std::string ModelMetadataExtractor::FindFirstAssociatedFileName( + const tflite::TensorMetadata& tensor_metadata, + tflite::AssociatedFileType type, absl::string_view locale) { + if (tensor_metadata.associated_files() == nullptr) { + return std::string(); + } + for (const auto associated_file : *tensor_metadata.associated_files()) { + if (associated_file->type() != type || associated_file->name() == nullptr) { + continue; + } + if (locale.empty() || (associated_file->locale() != nullptr && + locale == associated_file->locale()->str())) { + return associated_file->name()->str(); + } + } + return std::string(); +} + +absl::Status ModelMetadataExtractor::InitFromModelBuffer( + const char* buffer_data, size_t buffer_size) { + // Rely on the simplest, base flatbuffers verifier. Here is not the place to + // e.g. use an OpResolver: we just want to make sure the buffer is valid to + // access the metadata. 
+ flatbuffers::Verifier verifier = flatbuffers::Verifier( + reinterpret_cast<const uint8_t*>(buffer_data), buffer_size); + if (!tflite::VerifyModelBuffer(verifier)) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "The model is not a valid FlatBuffer buffer.", + TfLiteSupportStatus::kInvalidFlatBufferError); + } + model_ = tflite::GetModel(buffer_data); + if (model_->metadata() == nullptr) { + // Not all models have metadata, which is OK. `GetModelMetadata()` then + // returns nullptr. + return absl::OkStatus(); + } + // Look for the "TFLITE_METADATA" field, if any. + for (int i = 0; i < model_->metadata()->size(); ++i) { + const auto metadata = model_->metadata()->Get(i); + if (metadata->name()->str() != kMetadataBufferName) { + continue; + } + const auto buffer_index = metadata->buffer(); + const auto metadata_buffer = + model_->buffers()->Get(buffer_index)->data()->data(); + if (!tflite::ModelMetadataBufferHasIdentifier(metadata_buffer)) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Invalid metadata schema version: expected %s, got %s", + absl::string_view(tflite::ModelMetadataIdentifier()) + .substr( + 0, flatbuffers::FlatBufferBuilder::kFileIdentifierLength), + // Returned identifier is not null terminated; has to be + // truncated. + absl::string_view( + flatbuffers::GetBufferIdentifier(metadata_buffer)) + .substr( + 0, + flatbuffers::FlatBufferBuilder::kFileIdentifierLength)), + TfLiteSupportStatus::kMetadataInvalidSchemaVersionError); + } + model_metadata_ = tflite::GetModelMetadata(metadata_buffer); + if (model_metadata_ == nullptr) { + return CreateStatusWithPayload(StatusCode::kInternal, + "Expected Model Metadata not to be null."); + } + return ExtractAssociatedFiles(buffer_data, buffer_size); + break; + } + return absl::OkStatus(); +} + +absl::Status ModelMetadataExtractor::ExtractAssociatedFiles( + const char* buffer_data, size_t buffer_size) { + // Setup libzip error reporting. 
+ zip_error_t error; + zip_error_init(&error); + auto zip_error_cleanup = SimpleCleanUp([&error] { zip_error_fini(&error); }); + + // Initialize zip source. + zip_source_t* src = + zip_source_buffer_create(buffer_data, buffer_size, /*freep=*/0, &error); + if (src == nullptr) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrFormat("Can't create zip source from model buffer: %s", + zip_error_strerror(&error)), + TfLiteSupportStatus::kMetadataAssociatedFileZipError); + } + auto zip_source_cleanup = SimpleCleanUp([src] { zip_source_free(src); }); + + // Try opening zip source. + zip* zip_archive = zip_open_from_source(src, /*flags=*/0, &error); + if (zip_archive == nullptr) { + // It's OK if it fails: this means there are no associated files with this + // model. + return absl::OkStatus(); + } + auto zip_archive_cleanup = + SimpleCleanUp([zip_archive] { zip_close(zip_archive); }); + // As per the documentation [1] for zip_source_free, it should not be called + // after a successful call to zip_open_from_source. + // + // [1]: https://libzip.org/documentation/zip_source_free.html + std::move(zip_source_cleanup).Cancel(); + + const int num_files = zip_get_num_entries(zip_archive, /*flags=*/0); + for (int index = 0; index < num_files; ++index) { + // Get file stats. + struct zip_stat zip_file_stat; + zip_stat_init(&zip_file_stat); + zip_stat_index(zip_archive, index, /*flags=*/0, &zip_file_stat); + absl::string_view filename = zip_file_stat.name; + const auto unzip_filesize = zip_file_stat.size; + + // Open file. + zip_file* zip_file = zip_fopen_index(zip_archive, index, /*flags=*/0); + if (zip_file == nullptr) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrFormat("Unable to open associated file with name: %s", + zip_file_stat.name), + TfLiteSupportStatus::kMetadataAssociatedFileZipError); + } + auto zip_file_cleanup = SimpleCleanUp([zip_file] { zip_fclose(zip_file); }); + + // Unzip file. 
+ char* unzip_buffer = new char[unzip_filesize]; + auto unzip_buffer_cleanup = + SimpleCleanUp([unzip_buffer] { delete[] unzip_buffer; }); + if (zip_fread(zip_file, unzip_buffer, unzip_filesize) != unzip_filesize) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + absl::StrFormat("Unzipping failed for file: %s.", filename), + TfLiteSupportStatus::kMetadataAssociatedFileZipError); + } + + // Copy file contents in map. + associated_files_[filename] = std::string(unzip_buffer, unzip_filesize); + } + return absl::OkStatus(); +} + +tflite::support::StatusOr<absl::string_view> +ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const { + auto it = associated_files_.find(filename); + if (it == associated_files_.end()) { + return CreateStatusWithPayload( + StatusCode::kNotFound, + absl::StrFormat("No associated file with name: %s", filename), + TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError); + } + return it->second; +} + +const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* +ModelMetadataExtractor::GetInputTensorMetadata() const { + if (model_metadata_ == nullptr || + model_metadata_->subgraph_metadata() == nullptr) { + return nullptr; + } + return model_metadata_->subgraph_metadata() + ->Get(kDefaultSubgraphIndex) + ->input_tensor_metadata(); +} + +const tflite::TensorMetadata* ModelMetadataExtractor::GetInputTensorMetadata( + int index) const { + return GetItemFromVector<tflite::TensorMetadata>(GetInputTensorMetadata(), + index); +} + +int ModelMetadataExtractor::GetInputTensorCount() const { + const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* + input_tensor_metadata = GetInputTensorMetadata(); + return input_tensor_metadata == nullptr ? 
0 : input_tensor_metadata->size(); +} + +const Vector<Offset<TensorMetadata>>* +ModelMetadataExtractor::GetOutputTensorMetadata() const { + if (model_metadata_ == nullptr || + model_metadata_->subgraph_metadata() == nullptr) { + return nullptr; + } + return model_metadata_->subgraph_metadata() + ->Get(kDefaultSubgraphIndex) + ->output_tensor_metadata(); +} + +const tflite::TensorMetadata* ModelMetadataExtractor::GetOutputTensorMetadata( + int index) const { + return GetItemFromVector<tflite::TensorMetadata>(GetOutputTensorMetadata(), + index); +} + +int ModelMetadataExtractor::GetOutputTensorCount() const { + const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* + output_tensor_metadata = GetOutputTensorMetadata(); + return output_tensor_metadata == nullptr ? 0 : output_tensor_metadata->size(); +} + +const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* +ModelMetadataExtractor::GetInputProcessUnits() const { + if (model_metadata_ == nullptr || + model_metadata_->subgraph_metadata() == nullptr) { + return nullptr; + } + return model_metadata_->subgraph_metadata() + ->Get(kDefaultSubgraphIndex) + ->input_process_units(); +} + +const tflite::ProcessUnit* ModelMetadataExtractor::GetInputProcessUnit( + int index) const { + return GetItemFromVector<tflite::ProcessUnit>(GetInputProcessUnits(), index); +} + +int ModelMetadataExtractor::GetInputProcessUnitsCount() const { + const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* input_process_units = + GetInputProcessUnits(); + return input_process_units == nullptr ? 
0 : input_process_units->size(); +} + +const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* +ModelMetadataExtractor::GetOutputProcessUnits() const { + if (model_metadata_ == nullptr || + model_metadata_->subgraph_metadata() == nullptr) { + return nullptr; + } + return model_metadata_->subgraph_metadata() + ->Get(kDefaultSubgraphIndex) + ->output_process_units(); +} + +const tflite::ProcessUnit* ModelMetadataExtractor::GetOutputProcessUnit( + int index) const { + return GetItemFromVector<tflite::ProcessUnit>(GetOutputProcessUnits(), index); +} + +int ModelMetadataExtractor::GetOutputProcessUnitsCount() const { + const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* output_process_units = + GetOutputProcessUnits(); + return output_process_units == nullptr ? 0 : output_process_units->size(); +} + +} // namespace metadata +} // namespace tflite diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/tensorflow_lite_support/metadata/cc/metadata_extractor.h new file mode 100644 index 00000000..8eafe932 --- /dev/null +++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.h @@ -0,0 +1,157 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ +#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace metadata { + +// Extracts and provides easy access to the TFLite ModelMetadata [1] and +// corresponding associated files packed into a TFLite FlatBuffer, if any. +// +// [1]: https://www.tensorflow.org/lite/convert/metadata +class ModelMetadataExtractor { + public: + // Creates a ModelMetadataExtractor from the provided TFLite Model FlatBuffer + // and returns a pointer to the new object. Ownership is transferred to the + // caller. Returns an error if the creation failed, which may happen e.g. if + // the provided buffer is not a valid TFLite FlatBuffer. + // + // Warning: Does not take ownership of the provided buffer, which must outlive + // this object. + // + // It is recommended to obtain and manage the buffer through an + // ExternalFileHandler[1], which is optimized through mmap(2) to avoid having + // to load the entire buffer in memory when provided by path or file + // descriptor. + // + // [1]: + // tensorflow_lite_support/c/task/core/external_file_handler.h + static tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>> + CreateFromModelBuffer(const char* buffer_data, size_t buffer_size); + + // Returns the pointer to the *first* ProcessUnit with the provided type, or + // nullptr if none can be found. An error is returned if multiple + // ProcessUnit-s with the provided type are found. 
+ static tflite::support::StatusOr<const tflite::ProcessUnit*> + FindFirstProcessUnit(const tflite::TensorMetadata& tensor_metadata, + tflite::ProcessUnitOptions type); + + // Returns the name of the *first* associated file with the provided type and + // (optional) locale in the provided TensorMetadata, or an empty string if no + // such associated file can be found (which is not necessarily an error: some + // models have no associated files at all) or its `name` field is unspecified. + // Note: see `GetAssociatedFile` to read the actual file contents. + static std::string FindFirstAssociatedFileName( + const tflite::TensorMetadata& tensor_metadata, + tflite::AssociatedFileType type, + absl::string_view locale = absl::string_view()); + + // Returns a pointer to the extracted TFLite Model Metadata, or nullptr if no + // metadata was present in the Model FlatBuffer provided at creation time. + const tflite::ModelMetadata* GetModelMetadata() const { + return model_metadata_; + } + + // Gets the contents of the associated file with the provided name packed into + // the model metadata. An error is returned if there is no such associated + // file. + tflite::support::StatusOr<absl::string_view> GetAssociatedFile( + const std::string& filename) const; + + // Note: all methods below retrieves metadata of the *first* subgraph as + // default. + + // Gets the metadata for input tensors. + const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* + GetInputTensorMetadata() const; + + // Gets the metadata for the input tensor specified by the given index, or + // nullptr in case there is no metadata or the index is out of range. + const tflite::TensorMetadata* GetInputTensorMetadata(int index) const; + + // Gets the count of input tensors with metadata in the metadata FlatBuffer. + // In particular, 0 is returned when there is no metadata. + int GetInputTensorCount() const; + + // Gets the metadata for output tensors. 
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* + GetOutputTensorMetadata() const; + + // Gets the metadata for the output tensor specified by the given index, or + // nullptr in case there is no metadata or the index is out of range. + const tflite::TensorMetadata* GetOutputTensorMetadata(int index) const; + + // Gets the count of output tensors with metadata in the metadata FlatBuffer. + // In particular, 0 is returned when there is no metadata. + int GetOutputTensorCount() const; + + // Gets the input process units from SubgraphMetadata.input_process_units, + // could be nullptr. + const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>* + GetInputProcessUnits() const; + + // Gets the input process unit specified by the given index, or nullptr in + // case there is no input process unit or the index is out of range. + const tflite::ProcessUnit* GetInputProcessUnit(int index) const; + + // Gets the count of input process units. In particular, 0 is returned when + // there is no input process units. + int GetInputProcessUnitsCount() const; + + // Gets the output process units from SubgraphMetadata.output_process_units, + // could be nullptr. + const flatbuffers::Vector<flatbuffers::Offset<tflite::ProcessUnit>>* + GetOutputProcessUnits() const; + + // Gets the output process unit specified by the given index, or nullptr in + // case there is no output process unit or the index is out of range. + const tflite::ProcessUnit* GetOutputProcessUnit(int index) const; + + // Gets the count of output process units. In particular, 0 is returned when + // there is no output process units. + int GetOutputProcessUnitsCount() const; + + private: + static constexpr int kDefaultSubgraphIndex = 0; + // Private default constructor, called from CreateFromModel(). + ModelMetadataExtractor() = default; + // Initializes the ModelMetadataExtractor from the provided Model FlatBuffer. 
+ absl::Status InitFromModelBuffer(const char* buffer_data, size_t buffer_size); + // Extracts and stores in associated_files_ the associated files (if present) + // packed into the model FlatBuffer data. + absl::Status ExtractAssociatedFiles(const char* buffer_data, + size_t buffer_size); + // Pointer to the TFLite Model object from which to read the ModelMetadata. + const tflite::Model* model_{nullptr}; + // Pointer to the extracted ModelMetadata, if any. + const tflite::ModelMetadata* model_metadata_{nullptr}; + // The files associated with the ModelMetadata, as a map with the filename + // (corresponding to a basename, e.g. "labels.txt") as key and the file + // contents as value. + absl::flat_hash_map<std::string, std::string> associated_files_; +}; + +} // namespace metadata +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ diff --git a/tensorflow_lite_support/metadata/cc/metadata_parser.h.template b/tensorflow_lite_support/metadata/cc/metadata_parser.h.template new file mode 100644 index 00000000..7e260508 --- /dev/null +++ b/tensorflow_lite_support/metadata/cc/metadata_parser.h.template @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_PARSER_H_ +#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_PARSER_H_ + +namespace tflite { +namespace metadata { + +// The version of the metadata parser that this metadata versioning library is +// depending on. +inline constexpr char kMatadataParserVersion[] = "{LATEST_METADATA_PARSER_VERSION}"; + +} // namespace metadata +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_PARSER_H_ diff --git a/tensorflow_lite_support/metadata/cc/metadata_version.cc b/tensorflow_lite_support/metadata/cc/metadata_version.cc new file mode 100644 index 00000000..7679f6c4 --- /dev/null +++ b/tensorflow_lite_support/metadata/cc/metadata_version.cc @@ -0,0 +1,302 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow_lite_support/metadata/cc/metadata_version.h" + +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <ostream> +#include <string> +#include <vector> + +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/tools/logging.h" +#include "tensorflow_lite_support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace metadata { +namespace { + +// Members that are added to the metadata schema after the initial version +// of 1.0.0. +enum class SchemaMembers { + kAssociatedFileTypeVocabulary = 0, + kSubGraphMetadataInputProcessUnits = 1, + kSubGraphMetadataOutputProcessUnits = 2, + kProcessUnitOptionsBertTokenizerOptions = 3, + kProcessUnitOptionsSentencePieceTokenizerOptions = 4, + kSubGraphMetadataInputTensorGroups = 5, + kSubGraphMetadataOutputTensorGroups = 6, + kProcessUnitOptionsRegexTokenizerOptions = 7, +}; + +// Helper class to compare semantic versions in terms of three integers, major, +// minor, and patch. +class Version { + public: + explicit Version(int major, int minor = 0, int patch = 0) + : version_({major, minor, patch}) {} + + explicit Version(const std::string& version) { + const std::vector<std::string> vec = absl::StrSplit(version, '.'); + // The version string should always be less than four numbers. + TFLITE_DCHECK(vec.size() <= kElementNumber && !vec.empty()); + version_[0] = std::stoi(vec[0]); + version_[1] = vec.size() > 1 ? std::stoi(vec[1]) : 0; + version_[2] = vec.size() > 2 ? std::stoi(vec[2]) : 0; + } + + // Compares two semantic version numbers. + // + // Example results when comparing two versions strings: + // "1.9" precedes "1.14"; + // "1.14" precedes "1.14.1"; + // "1.14" and "1.14.0" are equal. 
+ // + // Returns the value 0 if the two versions are equal; a value less than 0 if + // *this precedes v; a value greater than 0 if v precedes *this. + int Compare(const Version& v) { + for (int i = 0; i < kElementNumber; ++i) { + if (version_[i] != v.version_[i]) { + return version_[i] < v.version_[i] ? -1 : 1; + } + } + return 0; + } + + // Converts version_ into a version string. + std::string ToString() { return absl::StrJoin(version_, "."); } + + private: + static constexpr int kElementNumber = 3; + std::array<int, kElementNumber> version_; +}; + +Version GetMemberVersion(SchemaMembers member) { + switch (member) { + case SchemaMembers::kAssociatedFileTypeVocabulary: + return Version(1, 0, 1); + case SchemaMembers::kSubGraphMetadataInputProcessUnits: + return Version(1, 1, 0); + case SchemaMembers::kSubGraphMetadataOutputProcessUnits: + return Version(1, 1, 0); + case SchemaMembers::kProcessUnitOptionsBertTokenizerOptions: + return Version(1, 1, 0); + case SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions: + return Version(1, 1, 0); + case SchemaMembers::kSubGraphMetadataInputTensorGroups: + return Version(1, 2, 0); + case SchemaMembers::kSubGraphMetadataOutputTensorGroups: + return Version(1, 2, 0); + case SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions: + return Version(1, 2, 1); + default: + // Should never happen. + TFLITE_LOG(FATAL) << "Unsupported schema member: " + << static_cast<int>(member); + } + // Should never happen. + return Version(0, 0, 0); +} + +// Updates min_version if it precedes the new_version. 
+inline void UpdateMinimumVersion(const Version& new_version, + Version* min_version) { + if (min_version->Compare(new_version) < 0) { + *min_version = new_version; + } +} + +template <typename T> +void UpdateMinimumVersionForTable(const T* table, Version* min_version); + +template <typename T> +void UpdateMinimumVersionForArray( + const flatbuffers::Vector<flatbuffers::Offset<T>>* array, + Version* min_version) { + if (array == nullptr) return; + + for (int i = 0; i < array->size(); ++i) { + UpdateMinimumVersionForTable<T>(array->Get(i), min_version); + } +} + +template <> +void UpdateMinimumVersionForTable<tflite::AssociatedFile>( + const tflite::AssociatedFile* table, Version* min_version) { + if (table == nullptr) return; + + if (table->type() == AssociatedFileType_VOCABULARY) { + UpdateMinimumVersion( + GetMemberVersion(SchemaMembers::kAssociatedFileTypeVocabulary), + min_version); + } +} + +template <> +void UpdateMinimumVersionForTable<tflite::ProcessUnit>( + const tflite::ProcessUnit* table, Version* min_version) { + if (table == nullptr) return; + + tflite::ProcessUnitOptions process_unit_type = table->options_type(); + if (process_unit_type == ProcessUnitOptions_BertTokenizerOptions) { + UpdateMinimumVersion( + GetMemberVersion( + SchemaMembers::kProcessUnitOptionsBertTokenizerOptions), + min_version); + } + if (process_unit_type == ProcessUnitOptions_SentencePieceTokenizerOptions) { + UpdateMinimumVersion( + GetMemberVersion( + SchemaMembers::kProcessUnitOptionsSentencePieceTokenizerOptions), + min_version); + } + if (process_unit_type == ProcessUnitOptions_RegexTokenizerOptions) { + UpdateMinimumVersion( + GetMemberVersion( + SchemaMembers::kProcessUnitOptionsRegexTokenizerOptions), + min_version); + } +} + +template <> +void UpdateMinimumVersionForTable<tflite::TensorMetadata>( + const tflite::TensorMetadata* table, Version* min_version) { + if (table == nullptr) return; + + // Checks the associated_files field. 
+ UpdateMinimumVersionForArray<tflite::AssociatedFile>( + table->associated_files(), min_version); + + // Checks the process_units field. + UpdateMinimumVersionForArray<tflite::ProcessUnit>(table->process_units(), + min_version); +} + +template <> +void UpdateMinimumVersionForTable<tflite::SubGraphMetadata>( + const tflite::SubGraphMetadata* table, Version* min_version) { + if (table == nullptr) return; + + // Checks in the input/output metadata arrays. + UpdateMinimumVersionForArray<tflite::TensorMetadata>( + table->input_tensor_metadata(), min_version); + UpdateMinimumVersionForArray<tflite::TensorMetadata>( + table->output_tensor_metadata(), min_version); + + // Checks the associated_files field. + UpdateMinimumVersionForArray<tflite::AssociatedFile>( + table->associated_files(), min_version); + + // Checks for the input_process_units field. + if (table->input_process_units() != nullptr) { + UpdateMinimumVersion( + GetMemberVersion(SchemaMembers::kSubGraphMetadataInputProcessUnits), + min_version); + UpdateMinimumVersionForArray<tflite::ProcessUnit>( + table->input_process_units(), min_version); + } + + // Checks for the output_process_units field. + if (table->output_process_units() != nullptr) { + UpdateMinimumVersion( + GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputProcessUnits), + min_version); + UpdateMinimumVersionForArray<tflite::ProcessUnit>( + table->output_process_units(), min_version); + } + + // Checks for the input_tensor_groups field. + if (table->input_tensor_groups() != nullptr) { + UpdateMinimumVersion( + GetMemberVersion(SchemaMembers::kSubGraphMetadataInputTensorGroups), + min_version); + } + + // Checks for the output_tensor_groups field. 
+ if (table->output_tensor_groups() != nullptr) { + UpdateMinimumVersion( + GetMemberVersion(SchemaMembers::kSubGraphMetadataOutputTensorGroups), + min_version); + } +} + +template <> +void UpdateMinimumVersionForTable<tflite::ModelMetadata>( + const tflite::ModelMetadata* table, Version* min_version) { + if (table == nullptr) { + // Should never happen, because VerifyModelMetadataBuffer has verified it. + TFLITE_LOG(FATAL) << "The ModelMetadata object is null."; + return; + } + + // Checks the subgraph_metadata field. + if (table->subgraph_metadata() != nullptr) { + for (int i = 0; i < table->subgraph_metadata()->size(); ++i) { + UpdateMinimumVersionForTable<tflite::SubGraphMetadata>( + table->subgraph_metadata()->Get(i), min_version); + } + } + + // Checks the associated_files field. + UpdateMinimumVersionForArray<tflite::AssociatedFile>( + table->associated_files(), min_version); +} + +} // namespace + +TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data, + size_t buffer_size, + std::string* min_version_str) { + flatbuffers::Verifier verifier = + flatbuffers::Verifier(buffer_data, buffer_size); + if (!tflite::VerifyModelMetadataBuffer(verifier)) { + TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer."; + return kTfLiteError; + } + + static constexpr char kDefaultVersion[] = "1.0.0"; + Version min_version = Version(kDefaultVersion); + + // Checks if any member declared after 1.0.0 (such as those in + // SchemaMembers) exists, and updates min_version accordingly. The minimum + // metadata parser version will be the largest version number of all fields + // that has been added to a metadata flatbuffer + const tflite::ModelMetadata* model_metadata = GetModelMetadata(buffer_data); + + // All tables in the metadata schema should have their dedicated + // UpdateMinimumVersionForTable<Foo>() methods, respectively. We'll gradually + // add these methods when new fields show up in later schema versions. 
+ // + // UpdateMinimumVersionForTable<Foo>() takes a const pointer of Foo. The + // pointer can be a nullptr if Foo is not populated into the corresponding + // table of the Flatbuffer object. In this case, + // UpdateMinimumVersionFor<Foo>() will be skipped. An exception is + // UpdateMinimumVersionForModelMetadata(), where ModelMetadata is the root + // table, and it won't be null. + UpdateMinimumVersionForTable<tflite::ModelMetadata>(model_metadata, + &min_version); + + *min_version_str = min_version.ToString(); + return kTfLiteOk; +} + +} // namespace metadata +} // namespace tflite diff --git a/tensorflow_lite_support/metadata/cc/metadata_version.h b/tensorflow_lite_support/metadata/cc/metadata_version.h new file mode 100644 index 00000000..6332aaec --- /dev/null +++ b/tensorflow_lite_support/metadata/cc/metadata_version.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_VERSION_H_ +#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_VERSION_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <string> + +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace metadata { + +// Gets the minimum metadata parser version that can fully understand all fields +// in a given metadata flatbuffer. TFLite Metadata follows Semantic Versioning +// 2.0. 
Each release version has the form MAJOR.MINOR.PATCH. +TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data, + size_t buffer_size, + std::string* min_version); + +} // namespace metadata +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_VERSION_H_ diff --git a/tensorflow_lite_support/metadata/cc/python/BUILD b/tensorflow_lite_support/metadata/cc/python/BUILD new file mode 100644 index 00000000..34e9a4f9 --- /dev/null +++ b/tensorflow_lite_support/metadata/cc/python/BUILD @@ -0,0 +1,22 @@ +load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = [ + "//tensorflow_lite_support/metadata:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "_pywrap_metadata_version", + srcs = [ + "metadata_version.cc", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_metadata_version", + deps = [ + "//tensorflow_lite_support/metadata/cc:metadata_version", + "@org_tensorflow//tensorflow/lite/c:common", + "@pybind11", + ], +) diff --git a/tensorflow_lite_support/metadata/cc/python/metadata_version.cc b/tensorflow_lite_support/metadata/cc/python/metadata_version.cc new file mode 100644 index 00000000..db3a29e5 --- /dev/null +++ b/tensorflow_lite_support/metadata/cc/python/metadata_version.cc @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow_lite_support/metadata/cc/metadata_version.h" + +#include "pybind11/pybind11.h" +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace metadata { + +PYBIND11_MODULE(_pywrap_metadata_version, m) { + m.doc() = R"pbdoc( + _pywrap_metadata_version + A module that returns the minimum metadata parser version of a given + metadata flatbuffer. + )pbdoc"; + + // Using pybind11 type conversions to convert between Python and native + // C++ types. There are other options to provide access to native Python types + // in C++ and vice versa. See the pybind 11 instrcution [1] for more details. + // Type converstions is recommended by pybind11, though the main downside + // is that a copy of the data must be made on every Python to C++ transition: + // this is needed since the C++ and Python versions of the same type generally + // won’t have the same memory layout. + // + // [1]: https://pybind11.readthedocs.io/en/stable/advanced/cast/index.html + m.def("GetMinimumMetadataParserVersion", + [](const std::string& buffer_data) -> std::string { + std::string min_version; + if (GetMinimumMetadataParserVersion( + reinterpret_cast<const uint8_t*>(buffer_data.c_str()), + buffer_data.length(), &min_version) != kTfLiteOk) { + pybind11::value_error( + "Error occurred when getting the minimum metadata parser " + "version of the metadata flatbuffer."); + } + return min_version; + }); +} + +} // namespace metadata +} // namespace tflite diff --git a/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD b/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD new file mode 100644 index 00000000..d4171bf9 --- /dev/null +++ b/tensorflow_lite_support/metadata/flatbuffers_lib/BUILD @@ -0,0 +1,22 @@ +load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + 
+pybind_extension( + name = "_pywrap_flatbuffers", + srcs = [ + "flatbuffers_lib.cc", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_flatbuffers", + deps = [ + "@flatbuffers", + "@local_config_python//:python_headers", + "@pybind11", + ], +) diff --git a/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc b/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc new file mode 100644 index 00000000..61857225 --- /dev/null +++ b/tensorflow_lite_support/metadata/flatbuffers_lib/flatbuffers_lib.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/idl.h" // from @flatbuffers +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace tflite { +namespace support { + +PYBIND11_MODULE(_pywrap_flatbuffers, m) { + pybind11::class_<flatbuffers::IDLOptions>(m, "IDLOptions") + .def(pybind11::init<>()) + .def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json); + pybind11::class_<flatbuffers::Parser>(m, "Parser") + .def(pybind11::init<const flatbuffers::IDLOptions&>()) + .def("parse", + [](flatbuffers::Parser* self, const std::string& source) { + return self->Parse(source.c_str()); + }) + .def_readonly("builder", &flatbuffers::Parser::builder_) + .def_readonly("error", &flatbuffers::Parser::error_); + pybind11::class_<flatbuffers::FlatBufferBuilder>(m, "FlatBufferBuilder") + .def("clear", &flatbuffers::FlatBufferBuilder::Clear) + .def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self, + const std::string& contents) { + self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()), + contents.length()); + }); + m.def("generate_text_file", &flatbuffers::GenerateTextFile); + m.def( + "generate_text", + [](const flatbuffers::Parser& parser, + const std::string& buffer) -> std::string { + std::string text; + if (!flatbuffers::GenerateText( + parser, reinterpret_cast<const void*>(buffer.c_str()), &text)) { + return ""; + } + return text; + }); +} + +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/metadata/metadata_schema.fbs b/tensorflow_lite_support/metadata/metadata_schema.fbs new file mode 100644 index 00000000..8faae0a8 --- /dev/null +++ b/tensorflow_lite_support/metadata/metadata_schema.fbs @@ -0,0 +1,686 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace tflite; + +// TFLite metadata contains both human readable and machine readable information +// about what the model does and how to use the model. It can be used as a +// README file, which elaborates the details of the model, each input/ouput +// tensor, and each associated file. +// +// An important use case of TFLite metadata is the TFLite codegen tool, which +// automatically generates the model interface based on the properties of the +// model and the tensors. The model interface provides high-level APIs to +// interact with the model, such as preprocessing the input data and running +// inferences. +// +// Entries marked with "<Codegen usage>" are used in TFLite codegen tool to +// generate the model interface. It is recommended to fill in at least those +// enties to boost the codegen performance. + +// The Metadata schema is versioned by the Semantic versioning number, such as +// MAJOR.MINOR.PATCH. It tracks the schema changes according to the rules below: +// * Bump up the MAJOR number when making potentially backwards incompatible +// changes. It must be incremented if the new changes break the backwards +// compatibility. It may also include minor and patch level changes as +// needed. The true backwards compatibility is indicated by the file +// identifier. 
+// * Bump up the MINOR number when making backwards compatible updates for +// major features, such as supporting new content types or adding new +// processing units. +// * Bump up the PATCH number when making small backwards compatible changes, +// such as adding a new fields or deprecating certain fields (not deleting +// them). +// +// ModelMetadata.min_parser_version indicates the minimum necessary metadata +// parser version to fully understand all fields in a given metadata flatbuffer. +// +// New fields and types will have associated comments with the schema version +// for which they were added. +// +// LINT.IfChange +// Schema Semantic version: 1.2.1 +// LINT.ThenChange(//tensorflow_lite_support/\ +// metadata/java/src/java/org/tensorflow/lite/support/metadata/\ +// MetadataParser.java) + +// This indicates the flatbuffer compatibility. The number will bump up when a +// break change is applied to the schema, such as removing fields or adding new +// fields to the middle of a table. +file_identifier "M001"; + +// History: +// 1.0.1 - Added VOCABULARY type to AssociatedFileType. +// 1.1.0 - Added BertTokenizerOptions to ProcessUnitOptions. +// Added SentencePieceTokenizerOptions to ProcessUnitOptions. +// Added input_process_units to SubGraphMetadata. +// Added output_process_units to SubGraphMetadata. +// 1.2.0 - Added input_tensor_group to SubGraphMetadata. +// Added output_tensor_group to SubGraphMetadata. +// 1.2.1 - Added RegexTokenizerOptions to ProcessUnitOptions. + +// File extension of any written files. +file_extension "tflitemeta"; + +// LINT.IfChange +enum AssociatedFileType : byte { + UNKNOWN = 0, + + // Files such as readme.txt. + DESCRIPTIONS = 1, + + // Contains labels that annotate certain axis of the tensor. For example, + // the label file in image classification. Those labels annotate the + // the output tensor, such that each value in the output tensor is the + // probability of that corresponding category specified by the label. 
+ // + // <Codegen usage>: + // If an output tensor has an associated file as TENSOR_AXIS_LABELS, return + // the output as a mapping between the labels and probability in the model + // interface. + // If multiple files of the same type are present, the first one is used by + // default; additional ones are to be distinguished from one another by their + // specified locale. + TENSOR_AXIS_LABELS = 2, + + // Contains labels that tensor values correspond to. For example, in + // the object detection model, one of the output tensors is the detected + // classes. And each value in the tensor refers to the index of label in the + // category label file. + // + // <Codegen usage>: + // If an output tensor has an associated file as TENSOR_VALUE_LABELS, convert + // the tensor values into labels, and return a list of string as the output. + // If multiple files of the same type are present, the first one is used by + // default; additional ones are to be distinguished from one another by their + // specified locale. + TENSOR_VALUE_LABELS = 3, + + // Contains sigmoid-based score calibration parameters, formatted as CSV. + // Lines contain for each index of an output tensor the scale, slope, offset + // and (optional) min_score parameters to be used for sigmoid fitting (in this + // order and in `strtof`-compatible [1] format). + // A line may be left empty to default calibrated scores for this index to + // default_score. + // In summary, each line should thus contain 0, 3 or 4 comma-separated values. + // + // See documentation for ScoreCalibrationOptions for details. + // + // [1]: https://en.cppreference.com/w/c/string/byte/strtof + TENSOR_AXIS_SCORE_CALIBRATION = 4, + + // Contains a list of unique words (characters separated by "\n" or in lines) + // that help to convert natural language words to embedding vectors. + // Added in: 1.0.1 + VOCABULARY = 5, +} + +table AssociatedFile { + // Name of this file. 
Need to be exact the same as the name of the actual file + // packed into the TFLite model as a zip file. + // + // <Codegen usage>: + // Locates to the actual file in the TFLite model. + name:string; + + // A description of what the file is. + description:string; + + // Type of the associated file. There may be special pre/post processing for + // some types. For example in image classification, a label file of the output + // will be used to convert object index into string. + // + // <Codegen usage>: + // Determines how to process the corresponding tensor. + type:AssociatedFileType; + + // An optional locale for this associated file (if applicable). It is + // recommended to use an ISO 639-1 letter code (e.g. "en" for English), + // optionally completed by a two letter region code (e.g. "en-US" for US + // English and "en-CA" for Canadian English). + // Leverage this in order to specify e.g multiple label files translated in + // different languages. + locale:string; +} + +// The basic content type for all tensors. +// +// <Codegen usage>: +// Input feature tensors: +// 1. Generates the method to load data from a TensorBuffer. +// 2. Creates the preprocessing logic. The default processing pipeline is: +// [NormalizeOp, QuantizeOp]. +// Output feature tensors: +// 1. Generates the method to return the output data to a TensorBuffer. +// 2. Creates the post-processing logic. The default processing pipeline is: +// [DeQuantizeOp]. +table FeatureProperties { +} + +// The type of color space of an image. +enum ColorSpaceType : byte { + UNKNOWN = 0, + RGB = 1, + GRAYSCALE = 2, +} + +table ImageSize { + width:uint; + height:uint; +} + +// The properties for image tensors. +// +// <Codegen usage>: +// Input image tensors: +// 1. Generates the method to load an image from a TensorImage. +// 2. Creates the preprocessing logic. The default processing pipeline is: +// [ResizeOp, NormalizeOp, QuantizeOp]. +// Output image tensors: +// 1. 
Generates the method to return the output data to a TensorImage. +// 2. Creates the post-processing logic. The default processing pipeline is: +// [DeQuantizeOp]. +table ImageProperties { + // The color space of the image. + // + // <Codegen usage>: + // Determines how to convert the color space of a given image from users. + color_space:ColorSpaceType; + + // Indicates the default value of image width and height if the tensor shape + // is dynamic. For fixed-size tensor, this size will be consistent with the + // expected size. + default_size:ImageSize; +} + +// The properties for tensors representing bounding boxes. +// +// <Codegen usage>: +// Input image tensors: NA. +// Output image tensors: parses the values into a data stucture that represents +// bounding boxes. For example, in the generated wrapper for Android, it returns +// the output as android.graphics.Rect objects. +enum BoundingBoxType : byte { + UNKNOWN = 0, + // Represents the bounding box by using the combination of boundaries, + // {left, top, right, bottom}. + // The default order is {left, top, right, bottom}. Other orders can be + // indicated by BoundingBoxProperties.index. + BOUNDARIES = 1, + + // Represents the bounding box by using the upper_left corner, width and + // height. + // The default order is {upper_left_x, upper_left_y, width, height}. Other + // orders can be indicated by BoundingBoxProperties.index. + UPPER_LEFT = 2, + + // Represents the bounding box by using the center of the box, width and + // height. The default order is {center_x, center_y, width, height}. Other + // orders can be indicated by BoundingBoxProperties.index. + CENTER = 3, + +} + +enum CoordinateType : byte { + // The coordinates are float values from 0 to 1. + RATIO = 0, + // The coordinates are integers. + PIXEL = 1, +} + +table BoundingBoxProperties { + // Denotes the order of the elements defined in each bounding box type. An + // empty index array represent the default order of each bounding box type. 
+ // For example, to denote the default order of BOUNDARIES, {left, top, right, + // bottom}, the index should be {0, 1, 2, 3}. To denote the order {left, + // right, top, bottom}, the order should be {0, 2, 1, 3}. + // + // The index array can be applied to all bounding box types to adjust the + // order of their corresponding underlying elements. + // + // <Codegen usage>: + // Indicates how to parse the bounding box values. + index:[uint]; + + // <Codegen usage>: + // Indicates how to parse the bounding box values. + type:BoundingBoxType; + + // <Codegen usage>: + // Indicates how to convert the bounding box back to the original image in + // pixels. + coordinate_type:CoordinateType; +} + +union ContentProperties { + FeatureProperties, + ImageProperties, + BoundingBoxProperties, +} + +table ValueRange { + min:int; + max:int; +} + +table Content { + // The properties that the content may have, indicating the type of the + // Content. + // + // <Codegen usage>: + // Indicates how to process the tensor. + content_properties:ContentProperties; + + // The range of dimensions that the content corresponds to. A NULL + // "range" indicates that the content uses up all dimensions, + // except the batch axis if applied. + // + // Here are all the possible situations of how a tensor is composed. + // Case 1: The tensor is a single object, such as an image. + // For example, the input of an image classifier + // (https://www.tensorflow.org/lite/models/image_classification/overview), + // a tensor of shape [1, 224, 224, 3]. Dimensions 1 to 3 correspond to the + // image. Since dimension 0 is a batch axis, which can be ignored, + // "range" can be left as NULL. + // + // Case 2: The tensor contains multiple instances of the same object. + // For example, the output tensor of detected bounding boxes of an object + // detection model + // (https://www.tensorflow.org/lite/models/object_detection/overview). + // The tensor shape is [1, 10, 4]. 
Here is the what the three dimensions + // represent for: + // dimension 0: the batch axis. + // dimension 1: the 10 objects detected with the highest confidence. + // dimension 2: the bounding boxes of the 10 detected objects. + // The tensor is essentially 10 bounding boxes. In this case, + // "range" should be {min=2; max=2;}. + // + // The output tensor of scores of the above object detection model has shape + // [1, 10], where + // dimension 0: the batch axis; + // dimension 1: the scores of the 10 detected objects. + // Set "range" to the number of dimensions which is {min=2; max=2;} to denote + // that every element in the tensor is an individual content object, i.e. a + // score in this example. + // + // Another example is the pose estimation model + // (https://www.tensorflow.org/lite/models/pose_estimation/overview). + // The output tensor of heatmaps is in the shape of [1, 9, 9, 17]. + // Here is the what the four dimensions represent for: + // dimension 0: the batch axis. + // dimension 1/2: the heatmap image. + // dimension 3: 17 body parts of a person. + // Even though the last axis is body part, the real content of this tensor is + // the heatmap. "range" should be [min=1; max=2]. + // + // Case 3: The tensor contains multiple different objects. (Not supported by + // Content at this point). + // Sometimes a tensor may contain multiple different objects, thus different + // contents. It is very common for regression models. For example, a model + // to predict the fuel efficiency + // (https://www.tensorflow.org/tutorials/keras/regression). + // The input tensor has shape [1, 9], consisting of 9 features, such as + // "Cylinders", "Displacement", "Weight", etc. In this case, dimension 1 + // contains 9 different contents. However, since these sub-dimension objects + // barely need to be specifically processed, their contents are not recorded + // in the metadata. 
Through, the name of each dimension can be set through + // TensorMetadata.dimension_names. + // + // Note that if it is not case 3, a tensor can only have one content type. + // + // <Codegen usage>: + // Case 1: return a processed single object of certain content type. + // Case 2: return a list of processed objects of certain content type. The + // generated model interface have API to random access those objects from + // the output. + range:ValueRange; +} + +// Parameters that are used when normalizing the tensor. +table NormalizationOptions{ + // mean and std are normalization parameters. Tensor values are normalized + // on a per-channel basis, by the formula + // (x - mean) / std. + // If there is only one value in mean or std, we'll propogate the value to + // all channels. + // + // Quantized models share the same normalization parameters as their + // corresponding float models. For example, an image input tensor may have + // the normalization parameter of + // mean = 127.5f and std = 127.5f. + // The image value will be normalized from [0, 255] to [-1, 1]. + // Then, for quantized models, the image data should be further quantized + // according to the quantization parameters. In the case of uint8, the image + // data will be scaled back to [0, 255], while for int8, the image data will + // be scaled to [-128, 127]. + // + // Both the normalization parameters and quantization parameters can be + // retrieved through the metadata extractor library. + // TODO(b/156644598): add link for the metadata extractor library. + + // Per-channel mean of the possible values used in normalization. + // + // <Codegen usage>: + // Apply normalization to input tensors accordingly. + mean:[float]; + + // Per-channel standard dev. of the possible values used in normalization. + // + // <Codegen usage>: + // Apply normalization to input tensors accordingly. 
+ std:[float]; +} + +// The different possible score transforms to apply to uncalibrated scores +// before applying score calibration. +enum ScoreTransformationType : byte { + // Identity function: g(x) = x. + IDENTITY = 0, + // Log function: g(x) = log(x). + LOG = 1, + // Inverse logistic function: g(x) = log(x) - log(1-x). + INVERSE_LOGISTIC = 2, +} + +// Options to perform score calibration on an output tensor through sigmoid +// functions. One of the main purposes of score calibration is to make scores +// across classes comparable, so that a common threshold can be used for all +// output classes. This is meant for models producing class predictions as +// output, e.g. image classification or detection models. +// +// For each index in the output tensor, this applies: +// * `f(x) = scale / (1 + e^-(slope*g(x)+offset))` if `x > min_score` or if no +// `min_score` has been specified, +// * `f(x) = default_score` otherwise or if no scale, slope and offset have been +// specified. +// Where: +// * scale, slope, offset and (optional) min_score are index-specific parameters +// * g(x) is an index-independent transform among those defined in +// ScoreTransformationType +// * default_score is an index-independent parameter. +// An AssociatedFile with type TANSOR_AXIS_SCORE_CALIBRATION specifying the +// index-specific parameters must be associated with the corresponding +// TensorMetadata for score calibration be applied. +table ScoreCalibrationOptions { + // The function to use for transforming the uncalibrated score before + // applying score calibration. + score_transformation:ScoreTransformationType; + + // The default calibrated score to apply if the uncalibrated score is + // below min_score or if no parameters were specified for a given index. + default_score:float; +} + +// Performs thresholding on output tensor values, in order to filter out +// low-confidence results. 
+table ScoreThresholdingOptions { + // The recommended global threshold below which results are considered + // low-confidence and should be filtered out. + global_score_threshold:float; +} + +// Performs Bert tokenization as in tf.text.BertTokenizer +// (https://github.com/tensorflow/text/blob/3599f6fcd2b780a2dc413b90fb9315464f10b314/docs/api_docs/python/text/BertTokenizer.md) +// Added in: 1.1.0 +table BertTokenizerOptions { + // The vocabulary files used in the BertTokenizer. + vocab_file:[AssociatedFile]; +} + +// Performs SentencePiece tokenization as in tf.text.SentencepieceTokenizer +// (https://github.com/tensorflow/text/blob/3599f6fcd2b780a2dc413b90fb9315464f10b314/docs/api_docs/python/text/SentencepieceTokenizer.md). +// Added in: 1.1.0 +table SentencePieceTokenizerOptions { + // The SentencePiece model files used in the SentencePieceTokenizer. + sentencePiece_model:[AssociatedFile]; + + // The optional vocabulary model files used in the SentencePieceTokenizer. + vocab_file:[AssociatedFile]; +} + +// Splits strings by the occurrences of delim_regex_pattern and converts the +// tokens into ids. For example, given +// delim_regex_pattern: "\W+", +// string: "Words, words, words.", +// the tokens after split are: "Words", "words", "words", "". +// And then the tokens can be converted into ids according to the vocab_file. +// Added in: 1.2.1 +table RegexTokenizerOptions { + delim_regex_pattern:string; + // The vocabulary files used to convert this tokens into ids. + vocab_file:[AssociatedFile]; +} + +// Options that are used when processing the tensor. +union ProcessUnitOptions { + NormalizationOptions, + ScoreCalibrationOptions, + ScoreThresholdingOptions, + // Added in: 1.1.0 + BertTokenizerOptions, + // Added in: 1.1.0 + SentencePieceTokenizerOptions, + // Added in: 1.2.1 + RegexTokenizerOptions +} + +// A process unit that is used to process the tensor out-of-graph. 
+table ProcessUnit { + options:ProcessUnitOptions; +} + + +// Statistics to describe a tensor. +table Stats { + // Max and min are not currently used in tflite.support codegen. They mainly + // serve as references for users to better understand the model. They can also + // be used to validate model pre/post processing results. + // If there is only one value in max or min, we'll propogate the value to + // all channels. + + // Per-channel maximum value of the tensor. + max:[float]; + + // Per-channel minimum value of the tensor. + min:[float]; +} + +// Metadata of a group of tensors. It may contain several tensors that will be +// grouped together in codegen. For example, the TFLite object detection model +// example (https://www.tensorflow.org/lite/models/object_detection/overview) +// has four outputs: classes, scores, bounding boxes, and number of detections. +// If the four outputs are bundled together using TensorGroup (for example, +// named as "detection result"), the codegen tool will generate the class, +// `DetectionResult`, which contains the class, score, and bouding box. And the +// outputs of the model will be converted to a list of `DetectionResults` and +// the number of detection. Note that the number of detection is a single +// number, therefore is inappropriate for the list of `DetectionResult`. +// Added in: 1.2.0 +table TensorGroup { + // Name of tensor group. + // + // <codegen usage>: + // Name of the joint class of the tensor group. + name:string; + + // Names of the tensors to group together, corresponding to + // TensorMetadata.name. + // + // <codegen usage>: + // Determines which tensors will be added to this group. All tensors in the + // group should have the same number of elements specified by Content.range. + tensor_names:[string]; +} + +// Detailed information of an input or output tensor. +table TensorMetadata { + // Name of the tensor. + // + // <Codegen usage>: + // The name of this tensor in the generated model interface. 
+ name:string; + + // A description of the tensor. + description:string; + + // A list of names of the dimensions in this tensor. The length of + // dimension_names need to match the number of dimensions in this tensor. + // + // <Codegen usage>: + // The name of each dimension in the generated model interface. See "Case 2" + // in the comments of Content.range. + dimension_names:[string]; + + // The content that represents this tensor. + // + // <Codegen usage>: + // Determines how to process this tensor. See each item in ContentProperties + // for the default process units that will be applied to the tensor. + content:Content; + + // The process units that are used to process the tensor out-of-graph. + // + // <Codegen usage>: + // Contains the parameters of the default processing pipeline for each content + // type, such as the normalization parameters in all content types. See the + // items under ContentProperties for the details of the default processing + // pipeline. + process_units:[ProcessUnit]; + + // The statistics of the tensor values. + stats:Stats; + + // A list of associated files of this tensor. + // + // <Codegen usage>: + // Contains processing parameters of this tensor, such as normalization. + associated_files:[AssociatedFile]; +} + +table SubGraphMetadata { + // Name of the subgraph. + // + // Note that, since TFLite only support one subgraph at this moment, the + // Codegen tool will use the name in ModelMetadata in the generated model + // interface. + name:string; + + // A description explains details about what the subgraph does. + description:string; + + // Metadata of all input tensors used in this subgraph. It matches extactly + // the input tensors specified by `SubGraph.inputs` in the TFLite + // schema.fbs file[2]. The number of `TensorMetadata` in the array should + // equal to the number of indices in `SubGraph.inputs`. + // + // [2]: tensorflow/lite/schema/schema.fbs + // <Codegen usage>: + // Determines how to process the inputs. 
+ input_tensor_metadata:[TensorMetadata]; + + // Metadata of all output tensors used in this subgraph. It matches extactly + // the output tensors specified by `SubGraph.outputs` in the TFLite + // schema.fbs file[2]. The number of `TensorMetadata` in the array should + // equal to the number of indices in `SubGraph.outputs`. + // + // <Codegen usage>: + // Determines how to process the outputs. + output_tensor_metadata:[TensorMetadata]; + + // A list of associated files of this subgraph. + associated_files:[AssociatedFile]; + + // Input process units of the subgraph. Some models may have complex pre and + // post processing logics where the process units do not work on one tensor at + // a time, but in a similar way of a TFLite graph. For example, in the + // MobileBert model (https://www.tensorflow.org/lite/models/bert_qa/overview), + // the inputs are: ids / mask / segment ids; + // the outputs are: end logits / start logits. + // The preprocessing converts the query string and the context string to the + // model inputs, and the post-processing converts the model outputs to the + // answer string. + // Added in: 1.1.0 + input_process_units:[ProcessUnit]; + + // Output process units of the subgraph. + // Added in: 1.1.0 + output_process_units:[ProcessUnit]; + + // Metadata of all input tensor groups used in this subgraph. + // + // <codegen usage>: + // Bundles the corresponding elements of the underlying input tensors together + // into a class, and converts those individual tensors into a list of the + // class objects. + // Added in: 1.2.0 + input_tensor_groups:[TensorGroup]; + + // Metadata of all output tensor groups used in this subgraph. + // + // <codegen usage>: + // Bundles the corresponding elements of the underlying output tensors + // together into a class, and converts those individual tensors into a list of + // the class objects. + // Added in: 1.2.0 + output_tensor_groups:[TensorGroup]; + +} + +table ModelMetadata { + // Name of the model. 
+ // + // <Codegen usage>: + // The name of the model in the generated model interface. + name:string; + + // Model description in schema. + description:string; + + // Version of the model that specified by model creators. + version:string; + + // Noted that, the minimum required TFLite runtime version that the model is + // compatible with, has already been added as a metadata entry in tflite + // schema. We'll decide later if we want to move it here, and keep it with + // other metadata entries. + + // Metadata of all the subgraphs of the model. The 0th is assumed to be the + // main subgraph. + // + // <Codegen usage>: + // Determines how to process the inputs and outputs. + subgraph_metadata:[SubGraphMetadata]; + + // The person who creates this model. + author:string; + + // Licenses that may apply to this model. + license:string; + + // A list of associated files of this model. + associated_files:[AssociatedFile]; + + // The minimum metadata parser version that can fully understand the fields in + // the metadata flatbuffer. The version is effectively the largest version + // number among the versions of all the fields populated and the smallest + // compatible version indicated by the file identifier. + // + // This field is automaticaly populated by the MetadataPopulator when + // the metadata is populated into a TFLite model. 
+ min_parser_version:string; +} +// LINT.ThenChange(//tensorflow_lite_support/\ +// metadata/cc/metadata_version.cc) + +root_type ModelMetadata; diff --git a/tensorflow_lite_support/opensource/opensource_only.files b/tensorflow_lite_support/opensource/opensource_only.files new file mode 100644 index 00000000..be426420 --- /dev/null +++ b/tensorflow_lite_support/opensource/opensource_only.files @@ -0,0 +1,36 @@ +tensorflow_lite_support/custom_ops/kernel/sentencepiece/native.bzl +tensorflow_lite_support/opensource/BUILD +tensorflow_lite_support/opensource/WORKSPACE +tensorflow_lite_support/opensource/cc_build_defs.bzl +tensorflow_lite_support/third_party/android/BUILD +tensorflow_lite_support/third_party/android/android.bzl.tpl +tensorflow_lite_support/third_party/android/android_configure.BUILD.tpl +tensorflow_lite_support/third_party/android/android_configure.bzl +tensorflow_lite_support/third_party/com_google_absl.BUILD +tensorflow_lite_support/third_party/darts_clone.BUILD +tensorflow_lite_support/third_party/fft2d/BUILD +tensorflow_lite_support/third_party/fft2d/LICENSE +tensorflow_lite_support/third_party/fft2d/fft.h +tensorflow_lite_support/third_party/fft2d/fft2d.BUILD +tensorflow_lite_support/third_party/fft2d/fft2d.h +tensorflow_lite_support/third_party/google_toolbox_for_mac.BUILD +tensorflow_lite_support/third_party/icu.BUILD +tensorflow_lite_support/third_party/libyuv.BUILD +tensorflow_lite_support/third_party/libzip.BUILD +tensorflow_lite_support/third_party/pybind11.BUILD +tensorflow_lite_support/third_party/python_runtime/BUILD +tensorflow_lite_support/third_party/six.BUILD +tensorflow_lite_support/third_party/stblib.BUILD +tensorflow_lite_support/third_party/toolchains/java/BUILD +tensorflow_lite_support/third_party/utf.BUILD +tensorflow_lite_support/third_party/zlib.BUILD +tensorflow_lite_support/tools/ci_build/build_all.sh +tensorflow_lite_support/tools/ci_build/common.sh +tensorflow_lite_support/tools/ci_build/common_win.bat 
+tensorflow_lite_support/tools/pip_package/BUILD +tensorflow_lite_support/tools/pip_package/MANIFEST.in +tensorflow_lite_support/tools/pip_package/README +tensorflow_lite_support/tools/pip_package/build_pip_package.sh +tensorflow_lite_support/tools/pip_package/setup.py +tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py +tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py
\ No newline at end of file
diff --git a/tensorflow_lite_support/tools/BUILD b/tensorflow_lite_support/tools/BUILD
new file mode 100644
index 00000000..c3525ca4
--- /dev/null
+++ b/tensorflow_lite_support/tools/BUILD
@@ -0,0 +1,20 @@
+package(
+    default_visibility = ["//visibility:public"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+py_binary(
+    name = "zip_files",
+    srcs = ["zip_files.py"],
+    python_version = "PY3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "@absl_py//absl:app",
+        "@absl_py//absl/flags",
+    ],
+)
+
+py_library(
+    name = "expect_flatbuffers_installed",
+    srcs = [],
+)
diff --git a/tensorflow_lite_support/tools/build_rules/expand_template.bzl b/tensorflow_lite_support/tools/build_rules/expand_template.bzl
new file mode 100644
index 00000000..717860ca
--- /dev/null
+++ b/tensorflow_lite_support/tools/build_rules/expand_template.bzl
@@ -0,0 +1,50 @@
+"""Build macro for libzip."""
+
+# forked from kythe/kythe/tools/build_rules/expand_template.bzl
+def _expand_template_impl(ctx):
+    ctx.actions.expand_template(
+        template = ctx.file.template,
+        output = ctx.outputs.out,
+        substitutions = ctx.attr.substitutions,
+    )
+
+expand_template = rule(
+    attrs = {
+        "out": attr.output(mandatory = True),
+        "substitutions": attr.string_dict(mandatory = True),
+        "template": attr.label(
+            mandatory = True,
+            allow_single_file = True,
+        ),
+    },
+    output_to_genfiles = True,
+    implementation = _expand_template_impl,
+)
+
+def cmake_substitutions(vars, defines = {}):
+    """Returns a dict of template substitutions combining `vars` and `defines`.
+
+    Args:
+      vars: will be turned into a dict replacing `${key}` and `@key@` with `value`.
+      defines: will be turned into a dict replacing `#cmakedefine` with `#define {value}`
+        if present is true, otherwise `/* #undef %s */`.
+ Returns: + substitutions + """ + subs = {} + for key, value in vars.items(): + subs["${%s}" % (key,)] = str(value) if value != None else "" + subs["@%s@" % (key,)] = str(value) if value != None else "" + + # TODO(shahms): Better handling of #cmakedefine delimiters and line endings to + # avoid the prefix-substitution problem. + # Potentially allow value to be: True, False, None or string. + # True/False => Same as current + # None => assume no suffix value, include \n in sub and replacement + # string => use string to lookup in vars and assume ${} or @@ tail? + for macro, present in defines.items(): + if present: + subs["#cmakedefine %s" % macro] = "#define %s" % macro + else: + subs["#cmakedefine %s" % macro] = "/* #undef %s */" % macro + return subs diff --git a/tensorflow_lite_support/tools/ci_build/build_all.sh b/tensorflow_lite_support/tools/ci_build/build_all.sh new file mode 100644 index 00000000..7e98b6c9 --- /dev/null +++ b/tensorflow_lite_support/tools/ci_build/build_all.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# External `build_all.sh` + +set -ex + +bazel build -c opt --config=monolithic \ + //tensorflow_lite_support/java:tensorflowlite_support \ + //tensorflow_lite_support/codegen/python:codegen \ + //tensorflow_lite_support/metadata/java:tensorflowlite_support_metadata_lib \ + //tensorflow_lite_support/metadata/cc:metadata_extractor \ + //tensorflow_lite_support/custom_ops/kernel:all \ + //tensorflow_lite_support/custom_ops/python:tflite_text_api + +# Build Task libraries. +bazel build -c opt --config=monolithic \ + --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ + //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base-task-api.aar \ + //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text:task-library-text \ + //tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision:task-library-vision + + +# Run Metadata tests. +bazel clean --expunge + +bazel test --test_output=all \ + //tensorflow_lite_support/metadata/python/tests:metadata_test \ + //tensorflow_lite_support/metadata/python/tests/metadata_writers:all + diff --git a/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh b/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh new file mode 100644 index 00000000..a36f97d3 --- /dev/null +++ b/tensorflow_lite_support/tools/ci_build/builds/pip_smoke_test.sh @@ -0,0 +1,114 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Pip install TensorFlow Lite Support and run basic test on the pip package. + +# Important: Use msys shell to run this script on Windows. + +set -e +set -x + +function run_smoke_test() { + VENV_TMP_DIR="$(mktemp -d)" + + if [[ "$OSTYPE" == "msys" ]]; then + VENV_TMP_DIR="$(cygpath -m $VENV_TMP_DIR)" + fi + + ${PYTHON_BIN_PATH} -m virtualenv -p ${PYTHON_BIN_PATH} "${VENV_TMP_DIR}" || \ + die "FAILED: Unable to create virtualenv" + + if [[ "$OSTYPE" == "msys" ]]; then + source "${VENV_TMP_DIR}/Scripts/activate" || \ + die "FAILED: Unable to activate virtualenv " + else + source "${VENV_TMP_DIR}/bin/activate" || \ + die "FAILED: Unable to activate virtualenv " + fi + + # install tflite-support + python -m pip install ${WHL_NAME} || \ + die "pip install (forcing to reinstall tflite-support) FAILED" + echo "Successfully installed pip package ${WHL_NAME}" + + # Download a test model + export TEST_MODEL="$(pwd)/test.tflite" + wget https://tfhub.dev/tensorflow/lite-model/mobilenet_v1_0.75_192_quantized/1/metadata/1\?lite-format\=tflite -O "$TEST_MODEL" + if [[ "$OSTYPE" == "msys" ]]; then + TEST_MODEL=$(cygpath -m $TEST_MODEL) + fi + + test_tfls_imports + + test_codegen + + # Deactivate from virtualenv. + deactivate || source deactivate || \ + die "FAILED: Unable to deactivate from existing virtualenv." + + echo "All smoke test passes!" +} + +function test_tfls_imports() { + TMP_DIR=$(mktemp -d) + pushd "${TMP_DIR}" + + # test for basic import and metadata display. + RET_VAL=$(python -c "from tflite_support import metadata; \ +md = metadata.MetadataDisplayer.with_model_file(\"$TEST_MODEL\"); \ +print(md.get_metadata_json())") + + # just check if the model name is there. + if ! 
[[ ${RET_VAL} == *"MobileNetV1 image classifier (quantized)"* ]]; then + echo "Unexpected return value: ${RET_VAL}" + echo "PIP smoke test on virtualenv FAILED, do not upload ${WHL_NAME}." + return 1 + fi + + RESULT=$? + + popd + return $RESULT +} + +function test_codegen() { + TMP_DIR=$(mktemp -d) + pushd "${TMP_DIR}" + + # test for basic import and metadata display. + tflite_codegen --model ${TEST_MODEL} --destination tmp + RESULT=$? + + # just check if the model name is there. + if [[ ${RESULT} -ne 0 ]]; then + echo "Unexpected return value: ${RESULT}" + echo "PIP smoke test on virtualenv FAILED, do not upload ${WHL_NAME}." + return 1 + fi + + popd + return $RESULT +} + +########################################################################### +# Main +########################################################################### +if [[ -z "${1}" ]]; then + echo "TFLite Support WHL path not given, unable to install and test." + return 1 +fi + +WHL_NAME=${1} +run_smoke_test diff --git a/tensorflow_lite_support/tools/ci_build/common.sh b/tensorflow_lite_support/tools/ci_build/common.sh new file mode 100644 index 00000000..4907bb1b --- /dev/null +++ b/tensorflow_lite_support/tools/ci_build/common.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# External `common.sh` + +# Keep in sync with tensorflow core and configure.py. +# TODO(b/158448780): Guard bazel version with IfChangeThenChange. +LATEST_BAZEL_VERSION=3.1.0 + +# Run flaky functions with retries. +# run_with_retry cmd +function run_with_retry { + eval "$1" + # If the command fails retry again in 60 seconds. + if [[ $? -ne 0 ]]; then + sleep 60 + eval "$1" + fi +} + +function die() { + echo "$@" 1>&2 ; exit 1; +} + +# A small utility to run the command and only print logs if the command fails. +# On success, all logs are hidden. +function readable_run { + # Disable debug mode to avoid printing of variables here. + set +x + result=$("$@" 2>&1) || die "$result" + echo "$@" + echo "Command completed successfully at $(date)" + set -x +} + +# TODO(b/158448780): Guard bazel installation with IfChangeThenChange. +function set_bazel_outdir { + mkdir -p /tmpfs/bazel_output + export TEST_TMPDIR=/tmpfs/bazel_output +} + +# Downloads bazelisk to ~/bin as `bazel`. +function install_bazelisk { + date + case "$(uname -s)" in + Darwin) local name=bazelisk-darwin-amd64 ;; + Linux) local name=bazelisk-linux-amd64 ;; + *) die "Unknown OS: $(uname -s)" ;; + esac + mkdir -p "$HOME/bin" + wget --no-verbose -O "$HOME/bin/bazel" \ + "https://github.com/bazelbuild/bazelisk/releases/download/v1.3.0/$name" + chmod u+x "$HOME/bin/bazel" + if [[ ! 
":$PATH:" =~ :"$HOME"/bin/?: ]]; then + PATH="$HOME/bin:$PATH" + fi + set_bazel_outdir + which bazel + bazel version + date +} + +# Install the given bazel version on linux +function update_bazel_linux { + if [[ -z "$1" ]]; then + BAZEL_VERSION=${LATEST_BAZEL_VERSION} + else + BAZEL_VERSION=$1 + fi + rm -rf ~/bazel + mkdir ~/bazel + + pushd ~/bazel + readable_run wget https://github.com/bazelbuild/bazel/releases/download/"${BAZEL_VERSION}"/bazel-"${BAZEL_VERSION}"-installer-linux-x86_64.sh + chmod +x bazel-*.sh + ./bazel-"${BAZEL_VERSION}"-installer-linux-x86_64.sh --user + rm bazel-"${BAZEL_VERSION}"-installer-linux-x86_64.sh + popd + + PATH="/home/kbuilder/bin:$PATH" + set_bazel_outdir + which bazel + bazel version +} diff --git a/tensorflow_lite_support/tools/ci_build/common_win.bat b/tensorflow_lite_support/tools/ci_build/common_win.bat new file mode 100644 index 00000000..35f39a72 --- /dev/null +++ b/tensorflow_lite_support/tools/ci_build/common_win.bat @@ -0,0 +1,29 @@ +:: Copyright 2019 The TensorFlow Authors. All Rights Reserved. +:: +:: Licensed under the Apache License, Version 2.0 (the "License"); +:: you may not use this file except in compliance with the License. +:: You may obtain a copy of the License at +:: +:: http://www.apache.org/licenses/LICENSE-2.0 +:: +:: Unless required by applicable law or agreed to in writing, software +:: distributed under the License is distributed on an "AS IS" BASIS, +:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +:: See the License for the specific language governing permissions and +:: limitations under the License. +:: ============================================================================= + +:: This script is shamefully borrowed from: +:: //third_party/tensorflow/tools/ci_build/release/common_win.bat.oss + +echo on + +@REM +@REM Setup Bazel +@REM +:: Download Bazel from github and make sure its found in PATH. 
+SET BAZEL_VERSION=3.1.0
+md C:\tools\bazel\
+wget -q https://github.com/bazelbuild/bazel/releases/download/%BAZEL_VERSION%/bazel-%BAZEL_VERSION%-windows-x86_64.exe -O C:/tools/bazel/bazel.exe
+SET PATH=C:\tools\bazel;%PATH%
+bazel version
diff --git a/tensorflow_lite_support/tools/ci_build/update_version.py b/tensorflow_lite_support/tools/ci_build/update_version.py
new file mode 100644
index 00000000..86fa588d
--- /dev/null
+++ b/tensorflow_lite_support/tools/ci_build/update_version.py
@@ -0,0 +1,120 @@
+# lint as: python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Update version code in the repo.
+
+We use a python script rather than GNU tools to avoid cross-platform
+difficulties.
+
+The script takes 3 arguments:
+  --src <path> a path pointing to the code repo.
+  --version <version> the new version code.
+  --nightly [default: false] when true, the version code will append a build
+    suffix (e.g. dev20201103)
+
+It should not be run by Bazel. Use it as a simple python script. 
+""" + +import argparse +import datetime +import os +import re + +SETUP_PY_PATH = "tensorflow_lite_support/tools/pip_package/setup.py" + + +def replace_string_in_line(search, replace, filename): + """Replace the string in every line of the file in-place.""" + with open(filename, "r") as f: + content = f.read() + with open(filename, "w") as f: + f.write(re.sub(search, replace, content)) + + +def get_current_version(path): + """Get the current version code from setup.py.""" + for line in open(os.path.join(path, SETUP_PY_PATH)): + match = re.search("^_VERSION = '([a-z0-9\\.\\-]+)'", line) + if match: + return match.group(1) + print("Cannot find current version!") + return None + + +def update_version(path, current_version, new_version): + """Update the version code in the codebase.""" + # Update setup.py + replace_string_in_line( + "_VERSION = '%s'" % current_version, + # pep440 requires such a replacement + "_VERSION = '%s'" % new_version.replace("-", "."), + os.path.join(path, SETUP_PY_PATH)) + + +class CustomTimeZone(datetime.tzinfo): + + def utcoffset(self, dt): + return -datetime.timedelta(hours=8) + + def tzname(self, dt): + return "UTC-8" + + def dst(self, dt): + return datetime.timedelta(0) + + +def remove_build_suffix(version): + """Remove build suffix (if exists) from a version.""" + if version.find("-dev") >= 0: + return version[:version.find("-dev")] + if version.find(".dev") >= 0: + return version[:version.find(".dev")] + if version.find("dev") >= 0: + return version[:version.find("dev")] + return version + + +def main(): + parser = argparse.ArgumentParser(description="Update TFLS version in repo") + parser.add_argument( + "--src", + help="a path pointing to the code repo", + required=True, + default="") + parser.add_argument("--version", help="the new SemVer code", default="") + parser.add_argument( + "--nightly", + help="if true, a build suffix will append to the version code. 
If " + "current version code or the <version> argument provided contains a " + "build suffix, the suffix will be replaced with the timestamp", + action="store_true") + args = parser.parse_args() + + path = args.src + current_version = get_current_version(path) + if not current_version: + return + new_version = args.version if args.version else current_version + if args.nightly: + new_version = remove_build_suffix(new_version) + # Use UTC-8 rather than uncertain local time. + d = datetime.datetime.now(tz=CustomTimeZone()) + new_version += "-dev" + d.strftime("%Y%m%d") + print("Updating version from %s to %s" % (current_version, new_version)) + update_version(path, current_version, new_version) + + +if __name__ == "__main__": + main() diff --git a/tensorflow_lite_support/tools/pip_package/BUILD b/tensorflow_lite_support/tools/pip_package/BUILD new file mode 100644 index 00000000..61df24a4 --- /dev/null +++ b/tensorflow_lite_support/tools/pip_package/BUILD @@ -0,0 +1,55 @@ +# Description: +# Tools for building the TensorFlow pip package. 
+ +load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib") + +package(default_visibility = ["//visibility:private"]) + +COMMON_PIP_DEPS = [ + ":licenses", + "MANIFEST.in", + "README", + "setup.py", + "//tensorflow_lite_support/codegen/python:codegen", + "//tensorflow_lite_support/metadata/python:metadata", +] + +filegroup( + name = "licenses", + data = [ + "//:LICENSE", + "@org_tensorflow//:LICENSE", + ] + if_not_system_lib( + "absl_py", + [ + "@absl_py//absl:LICENSE", + "@absl_py//absl/logging:LICENSE", + "@absl_py//absl/flags:LICENSE", + "@absl_py//absl/testing:LICENSE", + "@absl_py//absl/third_party/unittest3_backport:LICENSE", + ], + ), +) + +sh_binary( + name = "build_pip_package", + srcs = ["build_pip_package.sh"], + data = COMMON_PIP_DEPS + + select({ + "@org_tensorflow//tensorflow:windows": [ + ":simple_console_for_windows", + ], + "//conditions:default": [ + ], + }), +) + +# On Windows, python binary is a zip file of runfiles tree. +# Add everything to its data dependency for generating a runfiles tree +# for building the pip package on Windows. 
+py_binary( + name = "simple_console_for_windows", + srcs = ["simple_console_for_windows.py"], + data = COMMON_PIP_DEPS, + srcs_version = "PY2AND3", +) diff --git a/tensorflow_lite_support/tools/pip_package/MANIFEST.in b/tensorflow_lite_support/tools/pip_package/MANIFEST.in new file mode 100644 index 00000000..e44f271f --- /dev/null +++ b/tensorflow_lite_support/tools/pip_package/MANIFEST.in @@ -0,0 +1,9 @@ +include LICENSE +include README.md +recursive-include * *.py +recursive-include * *.pyd +recursive-include * *.fbs +recursive-include * *.so +recursive-include * *.dylib +recursive-include * *.dll +recursive-include * *.lib diff --git a/tensorflow_lite_support/tools/pip_package/README b/tensorflow_lite_support/tools/pip_package/README new file mode 100644 index 00000000..1e1f9d5a --- /dev/null +++ b/tensorflow_lite_support/tools/pip_package/README @@ -0,0 +1 @@ +TensorFlow Lite Support diff --git a/tensorflow_lite_support/tools/pip_package/build_pip_package.sh b/tensorflow_lite_support/tools/pip_package/build_pip_package.sh new file mode 100755 index 00000000..2f962e3f --- /dev/null +++ b/tensorflow_lite_support/tools/pip_package/build_pip_package.sh @@ -0,0 +1,232 @@ +#!/usr/bin/env bash +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +set -e + +function is_absolute { + [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]] +} + +function real_path() { + is_absolute "$1" && echo "$1" || echo "$PWD/${1#./}" +} + +function move_to_root_if_exists () { + arg_to_move="$1" + if [ -e "${arg_to_move}" ]; then + mv ${arg_to_move} ./ + fi +} + +function reorganize_includes() { + TMPDIR="${1%/}" +} + +PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" +function is_windows() { + if [[ "${PLATFORM}" =~ (cygwin|mingw32|mingw64|msys)_nt* ]]; then + true + else + false + fi +} + +function prepare_src() { + if [ $# -lt 1 ] ; then + echo "No destination dir provided" + exit 1 + fi + + TMPDIR="${1%/}" + mkdir -p "$TMPDIR" + EXTERNAL_INCLUDES="${TMPDIR}/tflite_support/include/external" + + echo $(date) : "=== Preparing sources in dir: ${TMPDIR}" + + if [ ! -d bazel-bin/tensorflow_lite_support ]; then + echo "Could not find bazel-bin. Did you run from the root of the build tree?" + exit 1 + fi + + if is_windows; then + rm -rf ./bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows_unzip + mkdir -p ./bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows_unzip + echo "Unzipping simple_console_for_windows.zip to create runfiles tree..." + unzip -o -q ./bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows.zip -d ./bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows_unzip + echo "Unzip finished." + # runfiles structure after unzip the python binary + RUNFILES=bazel-bin/tensorflow_lite_support/tools/pip_package/simple_console_for_windows_unzip/runfiles/org_tensorflow_lite_support + + # TODO(b/165872313): Investigate the case and remove the hack. + # On Windows, __init__.py are not auto genereated at directories that only + # contains Pybind libraries. 
+ touch "$RUNFILES/tensorflow_lite_support/metadata/cc/__init__.py" + touch "$RUNFILES/tensorflow_lite_support/metadata/cc/python/__init__.py" + touch "$RUNFILES/tensorflow_lite_support/metadata/flatbuffers_lib/__init__.py" + else + RUNFILES=bazel-bin/tensorflow_lite_support/tools/pip_package/build_pip_package.runfiles/org_tensorflow_lite_support + fi + + cp "$RUNFILES/LICENSE" "${TMPDIR}" + cp -R "$RUNFILES/tensorflow_lite_support" "${TMPDIR}" + + reorganize_includes "${TMPDIR}" + + cp tensorflow_lite_support/tools/pip_package/MANIFEST.in ${TMPDIR} + cp tensorflow_lite_support/tools/pip_package/README ${TMPDIR}/README.md + cp tensorflow_lite_support/tools/pip_package/setup.py ${TMPDIR} + + # A helper entry. + mkdir ${TMPDIR}/tflite_support + cp tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py ${TMPDIR}/tflite_support/__init__.py +} + +function build_wheel() { + if [ $# -lt 2 ] ; then + echo "No src and dest dir provided" + exit 1 + fi + + TMPDIR="$1" + DEST="$2" + PKG_NAME_FLAG="$3" + + # Before we leave the top-level directory, make sure we know how to + # call python. 
+ if [[ -e tools/python_bin_path.sh ]]; then + source tools/python_bin_path.sh + fi + + pushd ${TMPDIR} > /dev/null + + rm -f MANIFEST + echo $(date) : "=== Building wheel" + "${PYTHON_BIN_PATH:-python}" setup.py bdist_wheel ${PKG_NAME_FLAG} >/dev/null + mkdir -p ${DEST} + cp dist/* ${DEST} + popd > /dev/null + echo $(date) : "=== Output wheel file is in: ${DEST}" +} + +function usage() { + echo "Usage:" + echo "$0 [--src srcdir] [--dst dstdir] [options]" + echo "$0 dstdir [options]" + echo "" + echo " --src prepare sources in srcdir" + echo " will use temporary dir if not specified" + echo "" + echo " --dst build wheel in dstdir" + echo " if dstdir is not set do not build, only prepare sources" + echo "" + echo " Options:" + echo " --project_name <name> set project name to <name>" + echo " --version <version> reset the pip package version to <version>" + echo " --nightly_flag build TFLite Support nightly" + echo "" + echo "When using bazel, add the following flag: --run_under=\"cd \$PWD && \"" + echo "" + exit 1 +} + +function main() { + PKG_NAME_FLAG="" + PROJECT_NAME="" + NIGHTLY_BUILD=0 + SRCDIR="" + DSTDIR="" + CLEANSRC=1 + VERSION="" + while true; do + if [[ "$1" == "--help" ]]; then + usage + exit 1 + elif [[ "$1" == "--nightly_flag" ]]; then + NIGHTLY_BUILD=1 + elif [[ "$1" == "--project_name" ]]; then + shift + if [[ -z "$1" ]]; then + break + fi + PROJECT_NAME="$1" + elif [[ "$1" == "--version" ]]; then + shift + if [[ -z "$1" ]]; then + break + fi + VERSION="$1" + elif [[ "$1" == "--src" ]]; then + shift + SRCDIR="$(real_path $1)" + CLEANSRC=0 + elif [[ "$1" == "--dst" ]]; then + shift + DSTDIR="$(real_path $1)" + else + echo "Unrecognized flag: $1" + usage + exit 1 + fi + shift + + if [[ -z "$1" ]]; then + break + fi + done + + if [[ -z "$DSTDIR" ]] && [[ -z "$SRCDIR" ]]; then + echo "No destination dir provided" + usage + exit 1 + fi + + if [[ -z "$SRCDIR" ]]; then + # make temp srcdir if none set + SRCDIR="$(mktemp -d -t tmp.XXXXXXXXXX)" + fi + + if 
[[ -z "$DSTDIR" ]]; then + # only want to prepare sources + exit + fi + + if [[ -n ${PROJECT_NAME} ]]; then + PKG_NAME_FLAG="--project_name ${PROJECT_NAME}" + elif [[ ${NIGHTLY_BUILD} == "1" ]]; then + PKG_NAME_FLAG="--project_name tflite_support_nightly" + fi + + if [[ ${NIGHTLY_BUILD} == "1" ]]; then + # we use a script to update versions to avoid any tool differences on different platforms. + if [[ ! -z ${VERSION} ]]; then + python tensorflow_lite_support/tools/ci_build/update_version.py --src "." --version ${VERSION} --nightly + else + python tensorflow_lite_support/tools/ci_build/update_version.py --src "." --nightly + fi + elif [[ ! -z ${VERSION} ]]; then + python tensorflow_lite_support/tools/ci_build/update_version.py --src "." --version ${VERSION} + fi + + prepare_src "$SRCDIR" + + build_wheel "$SRCDIR" "$DSTDIR" "$PKG_NAME_FLAG" + + if [[ $CLEANSRC -ne 0 ]]; then + rm -rf "${TMPDIR}" + fi +} + +main "$@" diff --git a/tensorflow_lite_support/tools/pip_package/setup.py b/tensorflow_lite_support/tools/pip_package/setup.py new file mode 100644 index 00000000..460c5057 --- /dev/null +++ b/tensorflow_lite_support/tools/pip_package/setup.py @@ -0,0 +1,154 @@ +# lint as: python3 +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.. 
+# ============================================================================== +"""TFLite Support is a toolkit that helps users to develop ML and deploy TFLite models onto mobile devices. + +This PyPI package includes the Python bindings for following features: + + - Metadata schemas: wraps TFLite model schema and metadata schema in Python. + - Metadata populator and displayer: can be used to populate the metadata and + associated files into the model, as well as converting the populated metadata + into the json format. + - Android Codegen tool: generates the Java model interface used in Android for + a particular model. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import fnmatch +import os +import re +import sys + +from setuptools import Command +from setuptools import find_packages +from setuptools import setup +from setuptools.command.install import install as InstallCommandBase +from setuptools.dist import Distribution + +# This version string is semver compatible, but incompatible with pip. +# For pip, we will remove all '-' characters from this string, and use the +# result for pip. 
+_VERSION = '0.1.0' + +SETUP_PACKAGES = [ + 'pybind11 >= 2.6.0', +] + +REQUIRED_PACKAGES = [ + 'absl-py >= 0.7.0', + 'numpy >= 1.16.0', + 'flatbuffers >= 1.12', +] + SETUP_PACKAGES + +project_name = 'tflite-support' +if '--project_name' in sys.argv: + project_name_idx = sys.argv.index('--project_name') + project_name = sys.argv[project_name_idx + 1] + sys.argv.remove('--project_name') + sys.argv.pop(project_name_idx) + +DOCLINES = __doc__.split('\n') + +CONSOLE_SCRIPTS = [ + 'tflite_codegen = tensorflow_lite_support.codegen.python.codegen:main', +] + + +class BinaryDistribution(Distribution): + + def has_ext_modules(self): + return True + + +class InstallCommand(InstallCommandBase): + """Override the dir where the headers go.""" + + def finalize_options(self): + ret = InstallCommandBase.finalize_options(self) + self.install_lib = self.install_platlib + return ret + + +def find_files(pattern, root): + """Return all the files matching pattern below root dir.""" + for dirpath, _, files in os.walk(root): + for filename in fnmatch.filter(files, pattern): + yield os.path.join(dirpath, filename) + + +so_lib_paths = [ + i for i in os.listdir('.') + if os.path.isdir(i) and fnmatch.fnmatch(i, '_solib_*') +] + +matches = [] +for path in so_lib_paths: + matches.extend(['../' + x for x in find_files('*', path) if '.py' not in x]) + +EXTENSIONS = ['codegen/_pywrap_codegen.so'] + +headers = () + +setup( + name=project_name, + version=_VERSION.replace('-', ''), + description=DOCLINES[0], + long_description='\n'.join(DOCLINES), + long_description_content_type='text/markdown', + url='https://www.tensorflow.org/', + download_url='https://github.com/tensorflow/tflite-support/tags', + author='Google, LLC.', + author_email='packages@tensorflow.org', + # Contained modules and scripts. 
+ packages=find_packages(), + entry_points={ + 'console_scripts': CONSOLE_SCRIPTS, + }, + headers=headers, + setup_requires=SETUP_PACKAGES, + install_requires=REQUIRED_PACKAGES, + tests_require=REQUIRED_PACKAGES, + # Add in any packaged data. + include_package_data=True, + package_data={ + 'tflite-support': EXTENSIONS + matches, + }, + zip_safe=False, + distclass=BinaryDistribution, + cmdclass={ + 'install': InstallCommand, + }, + # PyPI package information. + classifiers=sorted([ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ]), + license='Apache 2.0', +) diff --git a/tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py b/tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py new file mode 100644 index 00000000..106528bb --- /dev/null +++ b/tensorflow_lite_support/tools/pip_package/simple_console_for_windows.py @@ -0,0 +1,33 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Start a simple interactive console with TensorFlow available.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import code +import sys + + +def main(_): + """Run an interactive console.""" + code.interact() + return 0 + + +if __name__ == '__main__': + sys.exit(main(sys.argv)) diff --git a/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py b/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py new file mode 100644 index 00000000..c5bf3de5 --- /dev/null +++ b/tensorflow_lite_support/tools/pip_package/tflite_support.__init__.py @@ -0,0 +1,28 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An import entry for the TFLite Support project. + +In the original project structure, all python targets are accessed by paths like +tensorflow_lite_support.metadata.metadata.MetadataDisplayer, which is verbose +and deep. This file provides some shortcuts. It's also compatible with our first +version Pip package. + +In pip build, this file will be renamed as tflite_support/__init__.py. 
+""" + +import flatbuffers +from tensorflow_lite_support.metadata import metadata_schema_py_generated +from tensorflow_lite_support.metadata import schema_py_generated +from tensorflow_lite_support.metadata.python import metadata diff --git a/tensorflow_lite_support/tools/zip_files.py b/tensorflow_lite_support/tools/zip_files.py new file mode 100644 index 00000000..9dc66236 --- /dev/null +++ b/tensorflow_lite_support/tools/zip_files.py @@ -0,0 +1,41 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Lint as: python3 +"""Creates a zip package of the files passed in.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import zipfile + +from absl import app +from absl import flags + +FLAGS = flags.FLAGS +flags.DEFINE_string("export_zip_path", None, "Path to zip file.") +flags.DEFINE_string("file_directory", None, "Path to the files to be zipped.") + + +def main(_): + with zipfile.ZipFile(FLAGS.export_zip_path, mode="w") as zf: + for root, _, files in os.walk(FLAGS.file_directory): + for f in files: + if f.endswith(".java"): + zf.write(os.path.join(root, f)) + + +if __name__ == "__main__": + app.run(main) diff --git a/third_party/BUILD b/third_party/BUILD new file mode 100644 index 00000000..fe756e1b --- /dev/null +++ b/third_party/BUILD @@ -0,0 +1 @@ +licenses(["notice"]) # Apache 2.0 diff --git a/third_party/android/BUILD b/third_party/android/BUILD new file mode 100644 index 00000000..fd69d4ba --- /dev/null +++ b/third_party/android/BUILD @@ -0,0 +1 @@ +# Placeholder to make bazel treat it as a package. diff --git a/third_party/android/android.bzl.tpl b/third_party/android/android.bzl.tpl new file mode 100644 index 00000000..e6ed4994 --- /dev/null +++ b/third_party/android/android.bzl.tpl @@ -0,0 +1,9 @@ +"""Set up configurable Android SDK and NDK dependencies.""" + +def android_workspace(): + # String for replacement in Bazel template. + # These will either be replaced by android_sdk_repository if various ENV + # variables are set when `local_config_android` repo_rule is run, or they + # will be replaced by noops otherwise. 
+ MAYBE_ANDROID_SDK_REPOSITORY + MAYBE_ANDROID_NDK_REPOSITORY diff --git a/third_party/android/android_configure.BUILD.tpl b/third_party/android/android_configure.BUILD.tpl new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/third_party/android/android_configure.BUILD.tpl diff --git a/third_party/android/android_configure.bzl b/third_party/android/android_configure.bzl new file mode 100644 index 00000000..2fd2d807 --- /dev/null +++ b/third_party/android/android_configure.bzl @@ -0,0 +1,95 @@ +"""Repository rule for Android SDK and NDK autoconfiguration. + +`android_configure` depends on the following environment variables: + + * `ANDROID_NDK_HOME`: Location of Android NDK root. + * `ANDROID_SDK_HOME`: Location of Android SDK root. + * `ANDROID_SDK_API_LEVEL`: Desired Android SDK API version. + * `ANDROID_NDK_API_LEVEL`: Desired Android NDK API version. + * `ANDROID_BUILD_TOOLS_VERSION`: Desired Android build tools version. + + +Writes Android SDK and NDK rules. + +Add the following to your WORKSPACE FILE: + +```python +android_configure(name = "local_config_android") +``` + +Args: + name: A unique name for this workspace rule. 
+""" + +_ANDROID_NDK_HOME = "ANDROID_NDK_HOME" +_ANDROID_SDK_HOME = "ANDROID_SDK_HOME" +_ANDROID_NDK_API_VERSION = "ANDROID_NDK_API_LEVEL" +_ANDROID_SDK_API_VERSION = "ANDROID_SDK_API_LEVEL" +_ANDROID_BUILD_TOOLS_VERSION = "ANDROID_BUILD_TOOLS_VERSION" + +_ANDROID_SDK_REPO_TEMPLATE = """ + native.android_sdk_repository( + name="androidsdk", + path="%s", + api_level=%s, + build_tools_version="%s", + ) +""" + +_ANDROID_NDK_REPO_TEMPLATE = """ + native.android_ndk_repository( + name="androidndk", + path="%s", + api_level=%s, + ) +""" + +def _android_autoconf_impl(repository_ctx): + """Implementation of the android_autoconf repository rule.""" + sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME) + sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION) + build_tools_version = repository_ctx.os.environ.get( + _ANDROID_BUILD_TOOLS_VERSION, + ) + ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME) + ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION) + + sdk_rule = "" + if all([sdk_home, sdk_api_level, build_tools_version]): + sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % ( + sdk_home, + sdk_api_level, + build_tools_version, + ) + + ndk_rule = "" + if all([ndk_home, ndk_api_level]): + ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level) + + if ndk_rule == "" and sdk_rule == "": + sdk_rule = "pass" + # TODO(xunkai): Add interactive configure script. 
+ + repository_ctx.template( + "BUILD", + Label("//third_party/android:android_configure.BUILD.tpl"), + ) + repository_ctx.template( + "android.bzl", + Label("//third_party/android:android.bzl.tpl"), + substitutions = { + "MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule, + "MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule, + }, + ) + +android_configure = repository_rule( + implementation = _android_autoconf_impl, + environ = [ + _ANDROID_SDK_API_VERSION, + _ANDROID_NDK_API_VERSION, + _ANDROID_BUILD_TOOLS_VERSION, + _ANDROID_NDK_HOME, + _ANDROID_SDK_HOME, + ], +) diff --git a/third_party/com_google_absl.BUILD b/third_party/com_google_absl.BUILD new file mode 100644 index 00000000..8fca145f --- /dev/null +++ b/third_party/com_google_absl.BUILD @@ -0,0 +1,5 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache + +exports_files(["LICENSE"]) diff --git a/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff b/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff new file mode 100644 index 00000000..0cd2dffa --- /dev/null +++ b/third_party/com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff @@ -0,0 +1,14 @@ +diff --git a/absl/time/internal/cctz/BUILD.bazel b/absl/time/internal/cctz/BUILD.bazel +index 9fceffe..e7f9d01 100644 +--- a/absl/time/internal/cctz/BUILD.bazel ++++ b/absl/time/internal/cctz/BUILD.bazel +@@ -69,8 +69,5 @@ cc_library( + "include/cctz/zone_info_source.h", + ], + linkopts = select({ +- ":osx": [ +- "-framework Foundation", +- ], + ":ios": [ + "-framework Foundation", + ],
\ No newline at end of file diff --git a/third_party/com_google_protobuf_fixes.diff b/third_party/com_google_protobuf_fixes.diff new file mode 100644 index 00000000..b9bc17ea --- /dev/null +++ b/third_party/com_google_protobuf_fixes.diff @@ -0,0 +1,140 @@ +diff --git a/BUILD b/BUILD +index 79871d621..51b3a063f 100644 +--- a/BUILD ++++ b/BUILD +@@ -26,7 +26,7 @@ config_setting( + # ZLIB configuration + ################################################################################ + +-ZLIB_DEPS = ["@zlib//:zlib"] ++ZLIB_DEPS = ["@zlib"] + + ################################################################################ + # Protobuf Runtime Library +@@ -157,6 +157,7 @@ cc_library( + includes = ["src/"], + linkopts = LINK_OPTS, + visibility = ["//visibility:public"], ++ alwayslink = 1, + ) + + PROTOBUF_DEPS = select({ +@@ -230,6 +231,7 @@ cc_library( + linkopts = LINK_OPTS, + visibility = ["//visibility:public"], + deps = [":protobuf_lite"] + PROTOBUF_DEPS, ++ alwayslink = 1, + ) + + # This provides just the header files for use in projects that need to build +@@ -318,13 +320,13 @@ cc_proto_library( + + [native_cc_proto_library( + name = proto + "_cc_proto", +- deps = [proto + "_proto"], + visibility = ["//visibility:private"], ++ deps = [proto + "_proto"], + ) for proto in WELL_KNOWN_PROTO_MAP.keys()] + + cc_proto_blacklist_test( + name = "cc_proto_blacklist_test", +- deps = [proto + "_cc_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()] ++ deps = [proto + "_cc_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()], + ) + + ################################################################################ +@@ -900,7 +902,6 @@ py_proto_library( + py_extra_srcs = glob(["python/**/__init__.py"]), + py_libs = [ + ":python_srcs", +- "@six//:six", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +@@ -1002,7 +1003,9 @@ cc_library( + # Note: We use `native_proto_common` here because we depend on an implementation-detail of + # `proto_lang_toolchain`, which 
may not be available on `proto_common`. + reject_blacklisted_files = hasattr(native_proto_common, "proto_lang_toolchain_rejects_files_do_not_use_or_we_will_break_you_without_mercy") ++ + cc_toolchain_blacklisted_protos = [proto + "_proto" for proto in WELL_KNOWN_PROTO_MAP.keys()] if reject_blacklisted_files else [":well_known_protos"] ++ + proto_lang_toolchain( + name = "cc_toolchain", + blacklisted_protos = cc_toolchain_blacklisted_protos, +diff --git a/protobuf.bzl b/protobuf.bzl +index 829464d44..4ac23594b 100644 +--- a/protobuf.bzl ++++ b/protobuf.bzl +@@ -87,6 +87,8 @@ def _proto_gen_impl(ctx): + for dep in ctx.attr.deps: + import_flags += dep.proto.import_flags + deps += dep.proto.deps ++ import_flags = depset(import_flags).to_list() ++ deps = depset(deps).to_list() + + if not ctx.attr.gen_cc and not ctx.attr.gen_py and not ctx.executable.plugin: + return struct( +diff --git a/src/google/protobuf/io/gzip_stream.h b/src/google/protobuf/io/gzip_stream.h +index b1ce1d36c..d5d560ea7 100644 +--- a/src/google/protobuf/io/gzip_stream.h ++++ b/src/google/protobuf/io/gzip_stream.h +@@ -47,10 +47,12 @@ + #include <google/protobuf/stubs/common.h> + #include <google/protobuf/io/zero_copy_stream.h> + #include <google/protobuf/port.h> +-#include <zlib.h> +- + #include <google/protobuf/port_def.inc> + ++#if HAVE_ZLIB ++#include <zlib.h> ++#endif // HAVE_ZLIB ++ + namespace google { + namespace protobuf { + namespace io { +@@ -76,8 +78,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream { + virtual ~GzipInputStream(); + + // Return last error message or NULL if no error. 
++#if HAVE_ZLIB + inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } ++#endif // HAVE_ZLIB + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size); +@@ -90,8 +94,10 @@ class PROTOBUF_EXPORT GzipInputStream : public ZeroCopyInputStream { + + ZeroCopyInputStream* sub_stream_; + ++ #if HAVE_ZLIB + z_stream zcontext_; + int zerror_; ++ #endif // HAVE_ZLIB + + void* output_buffer_; + void* output_position_; +@@ -142,9 +148,11 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream { + + virtual ~GzipOutputStream(); + ++#if HAVE_ZLIB + // Return last error message or NULL if no error. + inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } ++#endif // HAVE_ZLIB + + // Flushes data written so far to zipped data in the underlying stream. + // It is the caller's responsibility to flush the underlying stream if +@@ -177,8 +185,10 @@ class PROTOBUF_EXPORT GzipOutputStream : public ZeroCopyOutputStream { + void* sub_data_; + int sub_data_size_; + ++#if HAVE_ZLIB + z_stream zcontext_; + int zerror_; ++#endif //HAVE_ZLIB + void* input_buffer_; + size_t input_buffer_length_; + diff --git a/third_party/darts_clone.BUILD b/third_party/darts_clone.BUILD new file mode 100644 index 00000000..1d95ec2f --- /dev/null +++ b/third_party/darts_clone.BUILD @@ -0,0 +1,15 @@ +# Description: +# Darts-clone is a clone of Darts (Double-ARray Trie System). 
+ +licenses(["notice"]) + +exports_files(["LICENSE"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "darts_clone", + hdrs = [ + "include/darts.h", + ], +) diff --git a/third_party/fft2d/BUILD b/third_party/fft2d/BUILD new file mode 100644 index 00000000..863a1cef --- /dev/null +++ b/third_party/fft2d/BUILD @@ -0,0 +1,48 @@ +# Headers for 2D Fast Fourier Transform package +# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html +# This is a separate package because the original downloaded archive doesn't +# contain any header files. + +package( + default_visibility = ["//visibility:public"], +) + +# Unrestricted use; can only distribute original package. +# See fft/readme.txt +licenses(["notice"]) + +exports_files(["LICENSE"]) + +cc_library( + name = "fft2d_headers", + srcs = [ + "fft.h", + "fft2d.h", + ], +) + +objc_library( + name = "fft2d_headersd_ios", + srcs = [ + "fft.h", + "fft2d.h", + ], +) + +# Export the source code so that it could be compiled for Andoid native apps. +filegroup( + name = "fft2d_headers_srcs", + srcs = [ + "fft.h", + "fft2d.h", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = ["**/OWNERS"], + ), + visibility = ["//third_party/tensorflow:__subpackages__"], +) diff --git a/third_party/fft2d/LICENSE b/third_party/fft2d/LICENSE new file mode 100644 index 00000000..2bd85506 --- /dev/null +++ b/third_party/fft2d/LICENSE @@ -0,0 +1,3 @@ +Copyright(C) 1997,2001 Takuya OOURA (email: ooura@kurims.kyoto-u.ac.jp). +You may use, copy, modify this code for any purpose and +without fee. You may distribute this ORIGINAL package. diff --git a/third_party/fft2d/fft.h b/third_party/fft2d/fft.h new file mode 100644 index 00000000..36d838b7 --- /dev/null +++ b/third_party/fft2d/fft.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Declarations for 1D FFT routines in third_party/fft2d/fft2d. + +#ifndef FFT2D_FFT_H__ +#define FFT2D_FFT_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +extern void cdft(int, int, double *, int *, double *); +extern void rdft(int, int, double *, int *, double *); +extern void ddct(int, int, double *, int *, double *); +extern void ddst(int, int, double *, int *, double *); +extern void dfct(int, double *, double *, int *, double *); +extern void dfst(int, double *, double *, int *, double *); + +#ifdef __cplusplus +} +#endif + +#endif // FFT2D_FFT_H__ diff --git a/third_party/fft2d/fft2d.BUILD b/third_party/fft2d/fft2d.BUILD new file mode 100644 index 00000000..9fa5097f --- /dev/null +++ b/third_party/fft2d/fft2d.BUILD @@ -0,0 +1,45 @@ +# 2D Fast Fourier Transform package +# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html + +package( + default_visibility = ["//visibility:public"], +) + +# Unrestricted use; can only distribute original package. +licenses(["notice"]) + +exports_files(["readme2d.txt"]) + +FFT2D_SRCS = [ + "fftsg.c", + "fftsg2d.c", +] + +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, +) + +# This is the main 2D FFT library. The 2D FFTs in this library call +# 1D FFTs. In addition, fast DCTs are provided for the special case +# of 8x8 and 16x16. 
This code in this library is referred to as +# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft2d.html. +cc_library( + name = "fft2d", + srcs = FFT2D_SRCS, + linkopts = select({ + ":windows": [], + "//conditions:default": ["-lm"], + }), +) + +objc_library( + name = "fft2d_ios", + srcs = FFT2D_SRCS, +) + +# Export the source code so that it could be compiled for Andoid native apps. +filegroup( + name = "fft2d_srcs", + srcs = FFT2D_SRCS, +) diff --git a/third_party/fft2d/fft2d.h b/third_party/fft2d/fft2d.h new file mode 100644 index 00000000..d587b3b4 --- /dev/null +++ b/third_party/fft2d/fft2d.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Declarations for 2D FFT routines in third_party/fft2d/fft2d. 
+ +#ifndef FFT2D_FFT_H__ +#define FFT2D_FFT_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +extern void cdft2d(int, int, int, double **, double *, int *, double *); +extern void rdft2d(int, int, int, double **, double *, int *, double *); +extern void ddct2d(int, int, int, double **, double *, int *, double *); +extern void ddst2d(int, int, int, double **, double *, int *, double *); +extern void ddct8x8s(int isgn, double **a); +extern void ddct16x16s(int isgn, double **a); + +#ifdef __cplusplus +} +#endif + +#endif // FFT2D_FFT_H__ diff --git a/third_party/flatbuffers/BUILD b/third_party/flatbuffers/BUILD new file mode 100644 index 00000000..82bab3ff --- /dev/null +++ b/third_party/flatbuffers/BUILD @@ -0,0 +1 @@ +# This empty BUILD file is required to make Bazel treat this directory as a package. diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel new file mode 100644 index 00000000..1ee46f05 --- /dev/null +++ b/third_party/flatbuffers/BUILD.bazel @@ -0,0 +1,140 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE.txt"]) + +licenses(["notice"]) + +config_setting( + name = "freebsd", + values = {"cpu": "freebsd"}, +) + +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, +) + +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") + +# Public flatc library to compile flatbuffer files at runtime. +cc_library( + name = "flatbuffers", + hdrs = ["//:public_headers"], + linkstatic = 1, + strip_include_prefix = "/include", + visibility = ["//visibility:public"], + deps = ["//src:flatbuffers"], +) + +# Public C++ headers for the Flatbuffers library. 
+filegroup( + name = "public_headers", + srcs = [ + "include/flatbuffers/base.h", + "include/flatbuffers/code_generators.h", + "include/flatbuffers/flatbuffers.h", + "include/flatbuffers/flexbuffers.h", + "include/flatbuffers/hash.h", + "include/flatbuffers/idl.h", + "include/flatbuffers/minireflect.h", + "include/flatbuffers/reflection.h", + "include/flatbuffers/reflection_generated.h", + "include/flatbuffers/registry.h", + "include/flatbuffers/stl_emulation.h", + "include/flatbuffers/util.h", + ], + visibility = ["//:__subpackages__"], +) + +# Public flatc compiler library. +cc_library( + name = "flatc_library", + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + "@flatbuffers//src:flatc_library", + ], +) + +# Public flatc compiler. +cc_binary( + name = "flatc", + linkopts = select({ + ":freebsd": [ + "-lm", + ], + ":windows": [], + "//conditions:default": [ + "-lm", + "-ldl", + ], + }), + visibility = ["//visibility:public"], + deps = [ + "@flatbuffers//src:flatc", + ], +) + +filegroup( + name = "flatc_headers", + srcs = [ + "include/flatbuffers/flatc.h", + ], + visibility = ["//:__subpackages__"], +) + +# Library used by flatbuffer_cc_library rules. 
+cc_library( + name = "runtime_cc", + hdrs = [ + "include/flatbuffers/base.h", + "include/flatbuffers/flatbuffers.h", + "include/flatbuffers/flexbuffers.h", + "include/flatbuffers/stl_emulation.h", + "include/flatbuffers/util.h", + ], + linkstatic = 1, + strip_include_prefix = "/include", + visibility = ["//visibility:public"], +) + +filegroup( + name = "runtime_py_srcs", + srcs = [ + "python/flatbuffers/__init__.py", + "python/flatbuffers/builder.py", + "python/flatbuffers/compat.py", + "python/flatbuffers/encode.py", + "python/flatbuffers/number_types.py", + "python/flatbuffers/packer.py", + "python/flatbuffers/table.py", + "python/flatbuffers/util.py", + ], +) + +py_library( + name = "runtime_py", + srcs = [":runtime_py_srcs"], + visibility = ["//visibility:public"], +) + +filegroup( + name = "runtime_java_srcs", + srcs = glob(["java/com/google/flatbuffers/**/*.java"]), +) + +java_library( + name = "runtime_java", + srcs = [":runtime_java_srcs"], + visibility = ["//visibility:public"], +) + +android_library( + name = "runtime_android", + srcs = [":runtime_java_srcs"], + visibility = ["//visibility:public"], +) diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl new file mode 100644 index 00000000..986b8d12 --- /dev/null +++ b/third_party/flatbuffers/build_defs.bzl @@ -0,0 +1,617 @@ +"""BUILD rules for generating flatbuffer files.""" + +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +flatc_path = "@flatbuffers//:flatc" +zip_files = "//tensorflow_lite_support/tools:zip_files" + +DEFAULT_INCLUDE_PATHS = [ + "./", + "$(GENDIR)", + "$(BINDIR)", +] + +DEFAULT_FLATC_ARGS = [ + "--no-union-value-namespacing", + "--gen-object-api", +] + +def flatbuffer_library_public( + name, + srcs, + outs, + language_flag, + out_prefix = "", + includes = [], + include_paths = [], + compatible_with = [], + flatc_args = DEFAULT_FLATC_ARGS, + reflection_name = "", + reflection_visibility = None, + output_to_bindir = 
False): + """Generates code files for reading/writing the given flatbuffers in the requested language using the public compiler. + + Outs: + filegroup(name): all generated source files. + Fileset([reflection_name]): (Optional) all generated reflection binaries. + + Args: + name: Rule name. + srcs: Source .fbs files. Sent in order to the compiler. + outs: Output files from flatc. + language_flag: Target language flag. One of [-c, -j, -js]. + out_prefix: Prepend this path to the front of all generated files except on + single source targets. Usually is a directory name. + includes: Optional, list of filegroups of schemas that the srcs depend on. + include_paths: Optional, list of paths the includes files can be found in. + compatible_with: Optional, passed to genrule for environments this rule + can be built for. + flatc_args: Optional, list of additional arguments to pass to flatc. + reflection_name: Optional, if set this will generate the flatbuffer + reflection binaries for the schemas. + reflection_visibility: The visibility of the generated reflection Fileset. + output_to_bindir: Passed to genrule for output to bin directory. + """ + include_paths_cmd = ["-I %s" % (s) for s in include_paths] + + # '$(@D)' when given a single source target will give the appropriate + # directory. Appending 'out_prefix' is only necessary when given a build + # target with multiple sources. 
+ output_directory = ( + ("-o $(@D)/%s" % (out_prefix)) if len(srcs) > 1 else ("-o $(@D)") + ) + genrule_cmd = " ".join([ + "for f in $(SRCS); do", + "$(location %s)" % (flatc_path), + " ".join(flatc_args), + " ".join(include_paths_cmd), + language_flag, + output_directory, + "$$f;", + "done", + ]) + native.genrule( + name = name, + srcs = srcs, + outs = outs, + output_to_bindir = output_to_bindir, + compatible_with = compatible_with, + tools = includes + [flatc_path], + cmd = genrule_cmd, + message = "Generating flatbuffer files for %s:" % (name), + ) + if reflection_name: + reflection_genrule_cmd = " ".join([ + "for f in $(SRCS); do", + "$(location %s)" % (flatc_path), + "-b --schema", + " ".join(flatc_args), + " ".join(include_paths_cmd), + language_flag, + output_directory, + "$$f;", + "done", + ]) + reflection_outs = [ + (out_prefix + "%s.bfbs") % (s.replace(".fbs", "").split("/")[-1]) + for s in srcs + ] + native.genrule( + name = "%s_srcs" % reflection_name, + srcs = srcs, + outs = reflection_outs, + output_to_bindir = output_to_bindir, + compatible_with = compatible_with, + tools = includes + [flatc_path], + cmd = reflection_genrule_cmd, + message = "Generating flatbuffer reflection binary for %s:" % (name), + ) + # TODO(b/114456773): Make bazel rules proper and supported by flatbuffer + # Have to comment this since FilesetEntry is not supported in bazel + # skylark. 
+ # native.Fileset( + # name = reflection_name, + # out = "%s_out" % reflection_name, + # entries = [ + # native.FilesetEntry(files = reflection_outs), + # ], + # visibility = reflection_visibility, + # compatible_with = compatible_with, + # ) + +def flatbuffer_cc_library( + name, + srcs, + srcs_filegroup_name = "", + out_prefix = "", + includes = [], + include_paths = [], + compatible_with = [], + flatc_args = DEFAULT_FLATC_ARGS, + visibility = None, + srcs_filegroup_visibility = None, + gen_reflections = False): + '''A cc_library with the generated reader/writers for the given flatbuffer definitions. + + Outs: + filegroup([name]_srcs): all generated .h files. + filegroup(srcs_filegroup_name if specified, or [name]_includes if not): + Other flatbuffer_cc_library's can pass this in for their `includes` + parameter, if they depend on the schemas in this library. + Fileset([name]_reflection): (Optional) all generated reflection binaries. + cc_library([name]): library with sources and flatbuffers deps. + + Remarks: + ** Because the genrule used to call flatc does not have any trivial way of + computing the output list of files transitively generated by includes and + --gen-includes (the default) being defined for flatc, the --gen-includes + flag will not work as expected. The way around this is to add a dependency + to the flatbuffer_cc_library defined alongside the flatc included Fileset. + For example you might define: + + flatbuffer_cc_library( + name = "my_fbs", + srcs = [ "schemas/foo.fbs" ], + includes = [ "//third_party/bazz:bazz_fbs_includes" ], + ) + + In which foo.fbs includes a few files from the Fileset defined at + //third_party/bazz:bazz_fbs_includes. 
When compiling the library that + includes foo_generated.h, and therefore has my_fbs as a dependency, it + will fail to find any of the bazz *_generated.h files unless you also + add bazz's flatbuffer_cc_library to your own dependency list, e.g.: + + cc_library( + name = "my_lib", + deps = [ + ":my_fbs", + "//third_party/bazz:bazz_fbs" + ], + ) + + Happy dependent Flatbuffering! + + Args: + name: Rule name. + srcs: Source .fbs files. Sent in order to the compiler. + srcs_filegroup_name: Name of the output filegroup that holds srcs. Pass this + filegroup into the `includes` parameter of any other + flatbuffer_cc_library that depends on this one's schemas. + out_prefix: Prepend this path to the front of all generated files. Usually + is a directory name. + includes: Optional, list of filegroups of schemas that the srcs depend on. + ** SEE REMARKS BELOW ** + include_paths: Optional, list of paths the includes files can be found in. + compatible_with: Optional, passed to genrule for environments this rule + can be built for + flatc_args: Optional list of additional arguments to pass to flatc + (e.g. --gen-mutable). + visibility: The visibility of the generated cc_library. By default, use the + default visibility of the project. + srcs_filegroup_visibility: The visibility of the generated srcs filegroup. + By default, use the value of the visibility parameter above. + gen_reflections: Optional, if true this will generate the flatbuffer + reflection binaries for the schemas. 
+ ''' + output_headers = [ + (out_prefix + "%s_generated.h") % (s.replace(".fbs", "").split("/")[-1]) + for s in srcs + ] + reflection_name = "%s_reflection" % name if gen_reflections else "" + + flatbuffer_library_public( + name = "%s_srcs" % (name), + srcs = srcs, + outs = output_headers, + language_flag = "-c", + out_prefix = out_prefix, + includes = includes, + include_paths = include_paths, + compatible_with = compatible_with, + flatc_args = flatc_args, + reflection_name = reflection_name, + reflection_visibility = visibility, + ) + native.cc_library( + name = name, + hdrs = output_headers, + srcs = output_headers, + features = [ + "-parse_headers", + ], + deps = [ + "@flatbuffers//:runtime_cc", + ], + includes = ["."], + linkstatic = 1, + visibility = visibility, + compatible_with = compatible_with, + ) + + # A filegroup for the `srcs`. That is, all the schema files for this + # Flatbuffer set. + native.filegroup( + name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name), + srcs = srcs, + visibility = srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility, + compatible_with = compatible_with, + ) + +# Custom provider to track dependencies transitively. 
+FlatbufferInfo = provider( + fields = { + "transitive_srcs": "flatbuffer schema definitions.", + }, +) + +def _flatbuffer_schemas_aspect_impl(target, ctx): + _ignore = [target] + transitive_srcs = depset() + if hasattr(ctx.rule.attr, "deps"): + for dep in ctx.rule.attr.deps: + if FlatbufferInfo in dep: + transitive_srcs = depset(dep[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs]) + if hasattr(ctx.rule.attr, "srcs"): + for src in ctx.rule.attr.srcs: + if FlatbufferInfo in src: + transitive_srcs = depset(src[FlatbufferInfo].transitive_srcs, transitive = [transitive_srcs]) + for f in src.files: + if f.extension == "fbs": + transitive_srcs = depset([f], transitive = [transitive_srcs]) + return [FlatbufferInfo(transitive_srcs = transitive_srcs)] + +# An aspect that runs over all dependencies and transitively collects +# flatbuffer schema files. +_flatbuffer_schemas_aspect = aspect( + attr_aspects = [ + "deps", + "srcs", + ], + implementation = _flatbuffer_schemas_aspect_impl, +) + +# Rule to invoke the flatbuffer compiler. +def _gen_flatbuffer_srcs_impl(ctx): + outputs = ctx.attr.outputs + include_paths = ctx.attr.include_paths + if ctx.attr.no_includes: + no_includes_statement = ["--no-includes"] + else: + no_includes_statement = [] + + # Need to generate all files in a directory. 
+ if not outputs: + outputs = [ctx.actions.declare_directory("{}_all".format(ctx.attr.name))] + output_directory = outputs[0].path + else: + outputs = [ctx.actions.declare_file(output) for output in outputs] + output_directory = outputs[0].dirname + + deps = depset(ctx.files.srcs + ctx.files.deps, transitive = [ + dep[FlatbufferInfo].transitive_srcs + for dep in ctx.attr.deps + if FlatbufferInfo in dep + ]) + + include_paths_cmd_line = [] + for s in include_paths: + include_paths_cmd_line.extend(["-I", s]) + + for src in ctx.files.srcs: + ctx.actions.run( + inputs = deps, + outputs = outputs, + executable = ctx.executable._flatc, + arguments = [ + ctx.attr.language_flag, + "-o", + output_directory, + # Allow for absolute imports and referencing of generated files. + "-I", + "./", + "-I", + ctx.genfiles_dir.path, + "-I", + ctx.bin_dir.path, + ] + no_includes_statement + + include_paths_cmd_line + [ + "--no-union-value-namespacing", + "--gen-object-api", + src.path, + ], + progress_message = "Generating flatbuffer files for {}:".format(src), + ) + return [ + DefaultInfo(files = depset(outputs)), + ] + +_gen_flatbuffer_srcs = rule( + _gen_flatbuffer_srcs_impl, + attrs = { + "srcs": attr.label_list( + allow_files = [".fbs"], + mandatory = True, + ), + "outputs": attr.string_list( + default = [], + mandatory = False, + ), + "deps": attr.label_list( + default = [], + mandatory = False, + aspects = [_flatbuffer_schemas_aspect], + ), + "include_paths": attr.string_list( + default = [], + mandatory = False, + ), + "language_flag": attr.string( + mandatory = True, + ), + "no_includes": attr.bool( + default = False, + mandatory = False, + ), + "_flatc": attr.label( + default = Label("@flatbuffers//:flatc"), + executable = True, + cfg = "host", + ), + }, + output_to_genfiles = True, +) + +def _concat_flatbuffer_py_srcs_impl(ctx): + # Merge all generated python files. The files are concatenated and the + # import statements are removed. 
Finally we import the flatbuffer runtime + # library. + command = "echo 'import flatbuffers\n' > %s; " + command += "for f in $(find %s -name '*.py'); do cat $f | sed '/import flatbuffers/d' >> %s; done " + ctx.actions.run_shell( + inputs = ctx.attr.deps[0].files, + outputs = [ctx.outputs.out], + command = command % ( + ctx.outputs.out.path, + ctx.attr.deps[0].files.to_list()[0].path, + ctx.outputs.out.path, + ), + ) + +_concat_flatbuffer_py_srcs = rule( + _concat_flatbuffer_py_srcs_impl, + attrs = { + "deps": attr.label_list(mandatory = True), + }, + output_to_genfiles = True, + outputs = {"out": "%{name}.py"}, +) + +def flatbuffer_py_library( + name, + srcs, + deps = [], + include_paths = []): + """A py_library with the generated reader/writers for the given schema. + + This rule assumes that the schema files define non-conflicting names, so that + they can be merged in a single file. This is e.g. the case if only a single + namespace is used. + The rule call the flatbuffer compiler for all schema files and merges the + generated python files into a single file that is wrapped in a py_library. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files. (required) + deps: List of dependencies. + include_paths: Optional, list of paths the includes files can be found in. 
+ """ + all_srcs = "{}_srcs".format(name) + _gen_flatbuffer_srcs( + name = all_srcs, + srcs = srcs, + language_flag = "--python", + deps = deps, + include_paths = include_paths, + ) + all_srcs_no_include = "{}_srcs_no_include".format(name) + _gen_flatbuffer_srcs( + name = all_srcs_no_include, + srcs = srcs, + language_flag = "--python", + deps = deps, + no_includes = True, + include_paths = include_paths, + ) + concat_py_srcs = "{}_generated".format(name) + _concat_flatbuffer_py_srcs( + name = concat_py_srcs, + deps = [ + ":{}".format(all_srcs_no_include), + ], + ) + native.py_library( + name = name, + srcs = [ + ":{}".format(concat_py_srcs), + ], + srcs_version = "PY2AND3", + deps = deps, + ) + +def flatbuffer_java_library( + name, + srcs, + custom_package = "", + package_prefix = "", + include_paths = DEFAULT_INCLUDE_PATHS, + flatc_args = DEFAULT_FLATC_ARGS, + visibility = None): + """A java library with the generated reader/writers for the given flatbuffer definitions. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files including all includes. (required) + custom_package: Package name of generated Java files. If not specified + namespace in the schema files will be used. (optional) + package_prefix: like custom_package, but prefixes to the existing + namespace. (optional) + include_paths: List of paths that includes files can be found in. (optional) + flatc_args: List of additional arguments to pass to flatc. (optional) + visibility: Visibility setting for the java_library rule. 
(optional) + """ + out_srcjar = "java_%s_all.srcjar" % name + flatbuffer_java_srcjar( + name = "%s_srcjar" % name, + srcs = srcs, + out = out_srcjar, + custom_package = custom_package, + flatc_args = flatc_args, + include_paths = include_paths, + package_prefix = package_prefix, + ) + + native.filegroup( + name = "%s.srcjar" % name, + srcs = [out_srcjar], + ) + + native.java_library( + name = name, + srcs = [out_srcjar], + javacopts = ["-source 7 -target 7"], + deps = [ + "@flatbuffers//:runtime_java", + ], + visibility = visibility, + ) + +def flatbuffer_java_srcjar( + name, + srcs, + out, + custom_package = "", + package_prefix = "", + include_paths = DEFAULT_INCLUDE_PATHS, + flatc_args = DEFAULT_FLATC_ARGS): + """Generate flatbuffer Java source files. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files including all includes. (required) + out: Output file name. (required) + custom_package: Package name of generated Java files. If not specified + namespace in the schema files will be used. (optional) + package_prefix: like custom_package, but prefixes to the existing + namespace. (optional) + include_paths: List of paths that includes files can be found in. (optional) + flatc_args: List of additional arguments to pass to flatc. 
(optional) + """ + command_fmt = """set -e + tmpdir=$(@D) + schemas=$$tmpdir/schemas + java_root=$$tmpdir/java + rm -rf $$schemas + rm -rf $$java_root + mkdir -p $$schemas + mkdir -p $$java_root + + for src in $(SRCS); do + dest=$$schemas/$$src + rm -rf $$(dirname $$dest) + mkdir -p $$(dirname $$dest) + if [ -z "{custom_package}" ] && [ -z "{package_prefix}" ]; then + cp -f $$src $$dest + else + if [ -z "{package_prefix}" ]; then + sed -e "s/namespace\\s.*/namespace {custom_package};/" $$src > $$dest + else + sed -e "s/namespace \\([^;]\\+\\);/namespace {package_prefix}.\\1;/" $$src > $$dest + fi + fi + done + + flatc_arg_I="-I $$tmpdir/schemas" + for include_path in {include_paths}; do + flatc_arg_I="$$flatc_arg_I -I $$schemas/$$include_path" + done + + flatc_additional_args= + for arg in {flatc_args}; do + flatc_additional_args="$$flatc_additional_args $$arg" + done + + for src in $(SRCS); do + $(location {flatc_path}) $$flatc_arg_I --java $$flatc_additional_args -o $$java_root $$schemas/$$src + done + + $(location {zip_files}) -export_zip_path=$@ -file_directory=$$java_root + """ + genrule_cmd = command_fmt.format( + package_name = native.package_name(), + custom_package = custom_package, + package_prefix = package_prefix, + flatc_path = flatc_path, + zip_files = zip_files, + include_paths = " ".join(include_paths), + flatc_args = " ".join(flatc_args), + ) + + native.genrule( + name = name, + srcs = srcs, + outs = [out], + tools = [flatc_path, zip_files], + cmd = genrule_cmd, + ) + +def flatbuffer_android_library( + name, + srcs, + custom_package = "", + package_prefix = "", + include_paths = DEFAULT_INCLUDE_PATHS, + flatc_args = DEFAULT_FLATC_ARGS, + visibility = None): + """An android_library with the generated reader/writers for the given flatbuffer definitions. + + Args: + name: Rule name. (required) + srcs: List of source .fbs files including all includes. (required) + custom_package: Package name of generated Java files. 
If not specified + namespace in the schema files will be used. (optional) + package_prefix: like custom_package, but prefixes to the existing + namespace. (optional) + include_paths: List of paths that includes files can be found in. (optional) + flatc_args: List of additional arguments to pass to flatc. (optional) + visibility: Visibility setting for the android_library rule. (optional) + """ + out_srcjar = "android_%s_all.srcjar" % name + flatbuffer_java_srcjar( + name = "%s_srcjar" % name, + srcs = srcs, + out = out_srcjar, + custom_package = custom_package, + flatc_args = flatc_args, + include_paths = include_paths, + package_prefix = package_prefix, + ) + + native.filegroup( + name = "%s.srcjar" % name, + srcs = [out_srcjar], + ) + + # To support org.checkerframework.dataflow.qual.Pure. + checkerframework_annotations = [ + "@org_checkerframework_qual", + ] if "--java-checkerframework" in flatc_args else [] + + android_library( + name = name, + srcs = [out_srcjar], + javacopts = ["-source 7 -target 7"], + visibility = visibility, + deps = [ + "@flatbuffers//:runtime_android", + ] + checkerframework_annotations, + ) diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl new file mode 100644 index 00000000..dea463f2 --- /dev/null +++ b/third_party/flatbuffers/workspace.bzl @@ -0,0 +1,19 @@ +"""Loads the Flatbuffers library, used by TF Lite.""" + +load("//third_party:repo.bzl", "third_party_http_archive") + +def repo(): + third_party_http_archive( + name = "flatbuffers", + strip_prefix = "flatbuffers-1.12.0", + sha256 = "62f2223fb9181d1d6338451375628975775f7522185266cd5296571ac152bc45", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.12.0.tar.gz", + "https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz", + ], + build_file = "//third_party/flatbuffers:BUILD.bazel", + delete = ["build_defs.bzl"], + link_files = { + "//third_party/flatbuffers:build_defs.bzl": 
"build_defs.bzl", + }, + ) diff --git a/third_party/gflags/BUILD b/third_party/gflags/BUILD new file mode 100644 index 00000000..82bab3ff --- /dev/null +++ b/third_party/gflags/BUILD @@ -0,0 +1 @@ +# This empty BUILD file is required to make Bazel treat this directory as a package. diff --git a/third_party/gflags/fix_android_pthread_link.patch b/third_party/gflags/fix_android_pthread_link.patch new file mode 100644 index 00000000..9a0b3511 --- /dev/null +++ b/third_party/gflags/fix_android_pthread_link.patch @@ -0,0 +1,32 @@ +diff --git a/BUILD b/BUILD +index 0a5c9eb..d836578 100644 +--- a/BUILD ++++ b/BUILD +@@ -6,6 +6,11 @@ licenses(["notice"]) + + exports_files(["src/gflags_complections.sh", "COPYING.txt"]) + ++config_setting( ++ name = "android", ++ values = {"crosstool_top": "//external:android/crosstool"}, ++) ++ + load(":bazel/gflags.bzl", "gflags_sources", "gflags_library") + (hdrs, srcs) = gflags_sources(namespace=["gflags", "google"]) + gflags_library(hdrs=hdrs, srcs=srcs, threads=0) +diff --git a/bazel/gflags.bzl b/bazel/gflags.bzl +index cd0edad..5c1d8b5 100644 +--- a/bazel/gflags.bzl ++++ b/bazel/gflags.bzl +@@ -77,7 +77,10 @@ def gflags_library(hdrs=[], srcs=[], threads=1): + ] + linkopts = [] + if threads: +- linkopts.append("-lpthread") ++ linkopts += select({ ++ "//:android": [], ++ "//conditions:default": ["-lpthread"], ++ }) + else: + name += "_nothreads" + copts.append("-DNO_THREADS")
\ No newline at end of file diff --git a/third_party/gflags/workspace.bzl b/third_party/gflags/workspace.bzl new file mode 100644 index 00000000..194a9d3f --- /dev/null +++ b/third_party/gflags/workspace.bzl @@ -0,0 +1,16 @@ +"""Loads the GFlags repo and patch it with android linkopt fix.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +def repo(): + http_archive( + name = "com_github_gflags_gflags", + sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e", + strip_prefix = "gflags-2.2.1", + urls = [ + "http://mirror.tensorflow.org/github.com/gflags/gflags/archive/v2.2.1.tar.gz", + "https://github.com/gflags/gflags/archive/v2.2.1.tar.gz", + ], + patches = ["@//third_party/gflags:fix_android_pthread_link.patch"], + patch_args = ["-p1"], + ) diff --git a/third_party/google_toolbox_for_mac.BUILD b/third_party/google_toolbox_for_mac.BUILD new file mode 100644 index 00000000..8d7fecf3 --- /dev/null +++ b/third_party/google_toolbox_for_mac.BUILD @@ -0,0 +1,22 @@ +# Description: +# A collection of source from different Google projects that may be of use to +# developers working other Mac projects. 
+package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +exports_files( + ["UnitTest-Info.plist"], + visibility = ["//visibility:public"], +) + +objc_library( + name = "GTM_Defines", + hdrs = ["GTMDefines.h"], + includes = ["."], + visibility = ["//visibility:public"], +) diff --git a/third_party/icu.BUILD b/third_party/icu.BUILD new file mode 100644 index 00000000..7749dda0 --- /dev/null +++ b/third_party/icu.BUILD @@ -0,0 +1,97 @@ +"""Builds ICU library.""" + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files([ + "icu4c/LICENSE", + "icu4j/main/shared/licenses/LICENSE", +]) + +cc_library( + name = "headers", + hdrs = glob(["icu4c/source/common/unicode/*.h"]), + includes = [ + "icu4c/source/common", + ], + deps = [ + ], +) + +cc_library( + name = "common", + hdrs = glob(["icu4c/source/common/unicode/*.h"]), + includes = [ + "icu4c/source/common", + ], + deps = [ + ":icuuc", + ], +) + +alias( + name = "nfkc", + actual = ":common", +) + +alias( + name = "nfkc_cf", + actual = ":common", +) + +cc_library( + name = "icuuc", + srcs = glob( + [ + "icu4c/source/common/*.c", + "icu4c/source/common/*.cpp", + "icu4c/source/stubdata/*.cpp", + ], + ), + hdrs = glob([ + "icu4c/source/common/*.h", + ]), + copts = [ + "-DU_COMMON_IMPLEMENTATION", + ] + select({ + ":android": [ + "-fdata-sections", + "-DU_HAVE_NL_LANGINFO_CODESET=0", + "-Wno-deprecated-declarations", + ], + ":apple": [ + "-Wno-shorten-64-to-32", + "-Wno-unused-variable", + ], + ":windows": [ + "/utf-8", + "/DLOCALE_ALLOW_NEUTRAL_NAMES=0", + ], + "//conditions:default": [], + }), + tags = ["requires-rtti"], + visibility = [ + "//visibility:private", + ], + deps = [ + ":headers", + ], +) + +config_setting( + name = "android", + values = {"crosstool_top": "//external:android/crosstool"}, +) + +config_setting( + name = "apple", + values = {"cpu": "darwin"}, +) + +config_setting( + 
name = "windows", + values = {"cpu": "x64_windows"}, +) diff --git a/third_party/libyuv.BUILD b/third_party/libyuv.BUILD new file mode 100644 index 00000000..4b39a8c0 --- /dev/null +++ b/third_party/libyuv.BUILD @@ -0,0 +1,25 @@ +# Description: +# The libyuv package provides implementation yuv image conversion, rotation +# and scaling. + +licenses(["notice"]) # BSD license + +exports_files(["LICENSE"]) + +cc_library( + name = "libyuv", + srcs = glob( + [ + "source/*.cc", + "include/libyuv/*.h", + ], + ), + hdrs = [ + "include/libyuv.h", + "include/libyuv/compare.h", + "include/libyuv/convert.h", + "include/libyuv/video_common.h", + ], + includes = ["include"], + visibility = ["//visibility:public"], +) diff --git a/third_party/libzip.BUILD b/third_party/libzip.BUILD new file mode 100644 index 00000000..b69ccf41 --- /dev/null +++ b/third_party/libzip.BUILD @@ -0,0 +1,189 @@ +package( + default_visibility = ["//visibility:public"], +) + +load("@org_tensorflow_lite_support//tensorflow_lite_support/tools:build_rules/expand_template.bzl", "cmake_substitutions", "expand_template") + +_CMAKE_VARIABLES = { + "INT16_T_LIBZIP": 2, + "INT32_T_LIBZIP": 4, + "INT64_T_LIBZIP": 8, + "INT8_T_LIBZIP": 1, + "INT_LIBZIP": 4, + "LIBZIP_TYPES_INCLUDE": "#include <stdint.h>", + "LONG_LIBZIP": 8, + "LONG_LONG_LIBZIP": 8, + "PACKAGE_VERSION": "1.5.1", + "PACKAGE_VERSION_MAJOR": "1", + "PACKAGE_VERSION_MICRO": "1", + "PACKAGE_VERSION_MINOR": "5", + "SHORT_LIBZIP": 2, + "SIZEOF_OFF_T": 8, + "SIZE_T_LIBZIP": 8, + "SSIZE_T_LIBZIP": 8, + "UINT16_T_LIBZIP": 2, + "UINT32_T_LIBZIP": 4, + "UINT64_T_LIBZIP": 8, + "UINT8_T_LIBZIP": 1, + "__INT16_LIBZIP": None, + "__INT32_LIBZIP": None, + "__INT64_LIBZIP": None, + "__INT8_LIBZIP": None, +} + +_CMAKE_VARIABLES.update(dict([ + ( + "ZIP_{sign}INT{size}_T".format( + sign = sign.upper(), + size = size, + ), + "{sign}int{size}_t".format( + sign = sign.lower(), + size = size, + ), + ) + for sign in ("U", "") + for size in (8, 16, 32, 64) +])) + 
+_SUBSTITUTIONS = { + "@PACKAGE@": "libzip", + "@VERSION@": "1.5.1", # Keep in sync with actual package! +} + +_DEFINES = { + "HAVE_CLONEFILE": False, + "HAVE_COMMONCRYPTO": False, + "HAVE_CRYPTO": False, + "HAVE_DIRENT_H": False, + "HAVE_FICLONERANGE": False, + "HAVE_FILENO": True, + "HAVE_FSEEK": True, + "HAVE_FSEEKO": True, + "HAVE_FTELLO": True, + "HAVE_FTS_H": True, + "HAVE_GETPROGNAME": False, + "HAVE_GNUTLS": False, + "HAVE_LIBBZ2": False, + "HAVE_MKSTEMP": True, + "HAVE_NDIR_H": False, + "HAVE_OPEN": True, + "HAVE_OPENSSL": False, + "HAVE_SETMODE": False, + "HAVE_SHARED": True, + "HAVE_SNPRINTF": True, + "HAVE_SSIZE_T_LIBZIP": True, + "HAVE_STDBOOL_H": True, + "HAVE_STRCASECMP": True, + "HAVE_STRDUP": True, + "HAVE_STRICMP": False, + "HAVE_STRINGS_H": True, + "HAVE_STRTOLL": True, + "HAVE_STRTOULL": True, + "HAVE_STRUCT_TM_TM_ZONE": False, + "HAVE_SYS_DIR_H": False, + "HAVE_SYS_NDIR_H": False, + "HAVE_UNISTD_H": True, + "HAVE__CHMOD": False, + "HAVE__CLOSE": False, + "HAVE__DUP": False, + "HAVE__FDOPEN": False, + "HAVE__FILENO": False, + "HAVE__OPEN": False, + "HAVE__SETMODE": False, + "HAVE__SNPRINTF": False, + "HAVE__STRDUP": False, + "HAVE__STRICMP": False, + "HAVE__STRTOI64": False, + "HAVE__STRTOUI64": False, + "HAVE__UMASK": False, + "HAVE__UNLINK": False, + "HAVE___PROGNAME": False, + "WORDS_BIGENDIAN": False, +} + +_DEFINES.update(dict([( + key, + value != None, +) for key, value in _CMAKE_VARIABLES.items()])) + +_SUBSTITUTIONS.update(cmake_substitutions( + defines = _DEFINES, + vars = _CMAKE_VARIABLES, +)) + +expand_template( + name = "config_h", + out = "config.h", + substitutions = _SUBSTITUTIONS, + template = "cmake-config.h.in", +) + +_VARS = { + "LIBZIP_TYPES_INCLUDE": "#include <stdint.h>", + "PACKAGE_VERSION": "1.5.1", + "PACKAGE_VERSION_MAJOR": "1", + "PACKAGE_VERSION_MICRO": "1", + "PACKAGE_VERSION_MINOR": "5", +} + +_VARS.update(dict([ + ( + "ZIP_{sign}INT{size}_T".format( + sign = sign.upper(), + size = size, + ), + 
"{sign}int{size}_t".format( + sign = sign.lower(), + size = size, + ), + ) + for sign in ("U", "") + for size in (8, 16, 32, 64) +])) + +expand_template( + name = "zipconf_h", + out = "lib/zipconf.h", + substitutions = cmake_substitutions( + defines = { + "LIBZIP_VERSION": True, + "LIBZIP_VERSION_MAJOR": True, + "LIBZIP_VERSION_MICRO": True, + "LIBZIP_VERSION_MINOR": True, + "ZIP_STATIC": False, + }, + vars = _VARS, + ), + template = "cmake-zipconf.h.in", +) + +cc_library( + name = "zip", + srcs = glob( + [ + "lib/*.c", + "lib/*.h", + ], + exclude = [ + "lib/*win32*", + "lib/zip_random_uwp.c", + "lib/*crypto*", + "lib/*aes*", + "lib/*bzip2*", + ], + ) + [ + "config.h", + ], + hdrs = [ + "lib/zip.h", + "lib/zipconf.h", + ], + copts = [ + "-DHAVE_CONFIG_H", + ], + includes = ["lib"], + deps = [ + "@zlib", + ], +) diff --git a/third_party/py/BUILD b/third_party/py/BUILD new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/third_party/py/BUILD diff --git a/third_party/py/BUILD.tpl b/third_party/py/BUILD.tpl new file mode 100644 index 00000000..cc0e013b --- /dev/null +++ b/third_party/py/BUILD.tpl @@ -0,0 +1,31 @@ +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +# Point both runtimes to the same python binary to ensure we always +# use the python binary specified by ./configure.py script. 
+load("@bazel_tools//tools/python:toolchain.bzl", "py_runtime_pair") + +py_runtime( + name = "py2_runtime", + interpreter_path = "%{PYTHON_BIN_PATH}", + python_version = "PY2", +) + +py_runtime( + name = "py3_runtime", + interpreter_path = "%{PYTHON_BIN_PATH}", + python_version = "PY3", +) + +py_runtime_pair( + name = "py_runtime_pair", + py2_runtime = ":py2_runtime", + py3_runtime = ":py3_runtime", +) + +toolchain( + name = "py_toolchain", + toolchain = ":py_runtime_pair", + toolchain_type = "@bazel_tools//tools/python:toolchain_type", +) diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl new file mode 100644 index 00000000..6601d7f2 --- /dev/null +++ b/third_party/py/python_configure.bzl @@ -0,0 +1,71 @@ +"""Repository rule for Python autoconfiguration. + +`python_configure` depends on the following environment variables: + + * `PYTHON_BIN_PATH`: location of python binary. +""" + +_PYTHON_BIN_PATH = "PYTHON_BIN_PATH" + +def _tpl(repository_ctx, tpl, substitutions = {}, out = None): + if not out: + out = tpl + repository_ctx.template( + out, + Label("//third_party/py:%s.tpl" % tpl), + substitutions, + ) + +def _fail(msg): + """Output failure message when auto configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _get_python_bin(repository_ctx): + """Gets the python bin path.""" + python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH) + if python_bin != None: + return python_bin + python_bin_path = repository_ctx.which("python") + if python_bin_path != None: + return str(python_bin_path) + _fail("Cannot find python in PATH, please make sure " + + "python is installed and add its directory in PATH, or --define " + + "%s='/something/else'.\nPATH=%s" % ( + _PYTHON_BIN_PATH, + repository_ctx.os.environ.get("PATH", ""), + )) + +def _create_local_python_repository(repository_ctx): + """Creates the repository containing files set up to build 
with Python.""" + python_bin = _get_python_bin(repository_ctx) + _tpl(repository_ctx, "BUILD", { + "%{PYTHON_BIN_PATH}": python_bin, + }) + +def _python_autoconf_impl(repository_ctx): + """Implementation of the python_autoconf repository rule.""" + _create_local_python_repository(repository_ctx) + +python_configure = repository_rule( + implementation = _python_autoconf_impl, + environ = [ + _PYTHON_BIN_PATH, + ], +) +"""Detects and configures the local Python toolchain. + +Add the following to your WORKSPACE FILE: + +```python +load("//third_party/py:python_configure.bzl", "python_configure") + +python_configure(name = "local_config_py_toolchain") + +register_toolchains("@local_config_py_toolchain//:py_toolchain") +``` + +Args: + name: A unique name for this workspace rule. +""" diff --git a/third_party/pybind11.BUILD b/third_party/pybind11.BUILD new file mode 100644 index 00000000..2f1ada61 --- /dev/null +++ b/third_party/pybind11.BUILD @@ -0,0 +1,25 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "pybind11", + hdrs = glob( + include = [ + "include/pybind11/*.h", + "include/pybind11/detail/*.h", + ], + exclude = [ + "include/pybind11/common.h", + "include/pybind11/eigen.h", + ], + ), + copts = [ + "-fexceptions", + "-Wno-undefined-inline", + "-Wno-pragma-once-outside-header", + ], + includes = ["include"], + strip_include_prefix = "include", + deps = [ + "@org_tensorflow//third_party/python_runtime:headers", + ], +) diff --git a/third_party/python_runtime/BUILD b/third_party/python_runtime/BUILD new file mode 100644 index 00000000..2a160919 --- /dev/null +++ b/third_party/python_runtime/BUILD @@ -0,0 +1,8 @@ +licenses(["notice"]) # New BSD, Python Software Foundation + +package(default_visibility = ["//visibility:public"]) + +alias( + name = "headers", + actual = "@local_config_python//:python_headers", +) diff --git a/third_party/repo.bzl b/third_party/repo.bzl new file mode 100644 index 00000000..c9c6a834 --- /dev/null +++ 
b/third_party/repo.bzl @@ -0,0 +1,152 @@ +# Copyright 2020 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for defining TensorFlow Lite Support Bazel dependencies.""" + +_SINGLE_URL_WHITELIST = [] + +def _is_windows(ctx): + return ctx.os.name.lower().find("windows") != -1 + +def _wrap_bash_cmd(ctx, cmd): + if _is_windows(ctx): + bazel_sh = _get_env_var(ctx, "BAZEL_SH") + if not bazel_sh: + fail("BAZEL_SH environment variable is not set") + cmd = [bazel_sh, "-l", "-c", " ".join(["\"%s\"" % s for s in cmd])] + return cmd + +def _get_env_var(ctx, name): + if name in ctx.os.environ: + return ctx.os.environ[name] + else: + return None + +# Checks if we should use the system lib instead of the bundled one +def _use_system_lib(ctx, name): + syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS") + if syslibenv: + for n in syslibenv.strip().split(","): + if n.strip() == name: + return True + return False + +# Executes specified command with arguments and calls 'fail' if it exited with +# non-zero code +def _execute_and_check_ret_code(repo_ctx, cmd_and_args): + result = repo_ctx.execute(cmd_and_args, timeout = 60) + if result.return_code != 0: + fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" + + "Stderr: {3}").format( + " ".join([str(x) for x in cmd_and_args]), + result.return_code, + result.stdout, + result.stderr, + )) + +# Apply a patch_file to the repository root directory +# Runs 'patch -p1' 
on both Windows and Unix. +def _apply_patch(ctx, patch_file): + patch_command = ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)] + cmd = _wrap_bash_cmd(ctx, patch_command) + _execute_and_check_ret_code(ctx, cmd) + +def _apply_delete(ctx, paths): + for path in paths: + if path.startswith("/"): + fail("refusing to rm -rf path starting with '/': " + path) + if ".." in path: + fail("refusing to rm -rf path containing '..': " + path) + cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths]) + _execute_and_check_ret_code(ctx, cmd) + +def _third_party_http_archive(ctx): + """Downloads and creates Bazel repos for dependencies. + + This is a swappable replacement for both http_archive() and + new_http_archive() that offers some additional features. It also helps + ensure best practices are followed. + """ + if ("mirror.tensorflow.org" not in ctx.attr.urls[0] and + (len(ctx.attr.urls) < 2 and + ctx.attr.name not in _SINGLE_URL_WHITELIST.to_list())): + fail("third_party_http_archive(urls) must have redundant URLs. The " + + "mirror.tensorflow.org URL must be present and it must come first. " + + "Even if you don't have permission to mirror the file, please " + + "put the correctly formatted mirror URL there anyway, because " + + "someone will come along shortly thereafter and mirror the file.") + + use_syslib = _use_system_lib(ctx, ctx.attr.name) + + # Use "BUILD.bazel" to avoid conflict with third party projects that contain a + # file or directory called "BUILD" + buildfile_path = ctx.path("BUILD.bazel") + + if use_syslib: + if ctx.attr.system_build_file == None: + fail("Bazel was configured with TF_SYSTEM_LIBS to use a system " + + "library for %s, but no system build file for %s was configured. " + + "Please add a system_build_file attribute to the repository rule" + + "for %s." 
% (ctx.attr.name, ctx.attr.name, ctx.attr.name)) + ctx.symlink(Label(ctx.attr.system_build_file), buildfile_path) + + else: + ctx.download_and_extract( + ctx.attr.urls, + "", + ctx.attr.sha256, + ctx.attr.type, + ctx.attr.strip_prefix, + ) + if ctx.attr.delete: + _apply_delete(ctx, ctx.attr.delete) + if ctx.attr.patch_file != None: + _apply_patch(ctx, ctx.attr.patch_file) + ctx.symlink(Label(ctx.attr.build_file), buildfile_path) + + link_dict = {} + if use_syslib: + link_dict.update(ctx.attr.system_link_files) + + for internal_src, external_dest in ctx.attr.link_files.items(): + # if syslib and link exists in both, use the system one + if external_dest not in link_dict.values(): + link_dict[internal_src] = external_dest + + for internal_src, external_dest in link_dict.items(): + ctx.symlink(Label(internal_src), ctx.path(external_dest)) + +# For link_files, specify each dict entry as: +# "//path/to/source:file": "localfile" +third_party_http_archive = repository_rule( + attrs = { + "sha256": attr.string(mandatory = True), + "urls": attr.string_list( + mandatory = True, + allow_empty = False, + ), + "strip_prefix": attr.string(), + "type": attr.string(), + "delete": attr.string_list(), + "build_file": attr.string(mandatory = True), + "system_build_file": attr.string(mandatory = False), + "patch_file": attr.label(), + "link_files": attr.string_dict(), + "system_link_files": attr.string_dict(), + }, + environ = [ + "TF_SYSTEM_LIBS", + ], + implementation = _third_party_http_archive, +) diff --git a/third_party/six.BUILD b/third_party/six.BUILD new file mode 100644 index 00000000..a1b2f7b2 --- /dev/null +++ b/third_party/six.BUILD @@ -0,0 +1,14 @@ +# Description: +# Six provides simple utilities for wrapping over differences between Python 2 +# and Python 3. 
+ +licenses(["notice"]) # MIT + +exports_files(["LICENSE"]) + +py_library( + name = "six", + srcs = ["six.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], +) diff --git a/third_party/stblib.BUILD b/third_party/stblib.BUILD new file mode 100644 index 00000000..f1c361ac --- /dev/null +++ b/third_party/stblib.BUILD @@ -0,0 +1,26 @@ +# Description: +# Single-file C++ image decoding and encoding libraries + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # MIT license + +exports_files(["LICENSE"]) + +cc_library( + name = "stb_image", + hdrs = ["stb_image.h"], + copts = [ + "-Wno-unused-function", + "$(STACK_FRAME_UNLIMITED)", + ], + includes = ["."], +) + +cc_library( + name = "stb_image_write", + hdrs = ["stb_image_write.h"], + includes = ["."], +) diff --git a/third_party/tensorflow/BUILD b/third_party/tensorflow/BUILD new file mode 100644 index 00000000..ac039c46 --- /dev/null +++ b/third_party/tensorflow/BUILD @@ -0,0 +1 @@ +# placeholder to make the directory a bazel package. 
diff --git a/third_party/tensorflow/BUILD.tpl b/third_party/tensorflow/BUILD.tpl new file mode 100644 index 00000000..095021ed --- /dev/null +++ b/third_party/tensorflow/BUILD.tpl @@ -0,0 +1,18 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "tf_header_lib", + hdrs = [":tf_header_include"], + includes = ["include"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "libtensorflow_framework", + srcs = [":libtensorflow_framework.so"], + visibility = ["//visibility:public"], +) + +%{TF_HEADER_GENRULE} +%{TF_SHARED_LIBRARY_GENRULE} + diff --git a/third_party/tensorflow/tf_configure.bzl b/third_party/tensorflow/tf_configure.bzl new file mode 100644 index 00000000..32826255 --- /dev/null +++ b/third_party/tensorflow/tf_configure.bzl @@ -0,0 +1,224 @@ +"""Setup TensorFlow as external dependency""" + +_TF_HEADER_DIR = "TF_HEADER_DIR" +_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR" +_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME" + +def _tpl(repository_ctx, tpl, substitutions = {}, out = None): + if not out: + out = tpl + repository_ctx.template( + out, + Label("//third_party/tensorflow:%s.tpl" % tpl), + substitutions, + ) + +def _fail(msg): + """Output failure message when auto configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _is_windows(repository_ctx): + """Returns true if the host operating system is windows.""" + os_name = repository_ctx.os.name.lower() + if os_name.find("windows") != -1: + return True + return False + +def _execute( + repository_ctx, + cmdline, + error_msg = None, + error_details = None, + empty_stdout_fine = False): + """Executes an arbitrary shell command. + + Helper for executes an arbitrary shell command. + + Args: + repository_ctx: the repository_ctx object. + cmdline: list of strings, the command to execute. + error_msg: string, a summary of the error if the command fails. 
+ error_details: string, details about the error or steps to fix it. + empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise + it's an error. + + Returns: + The result of repository_ctx.execute(cmdline). + """ + result = repository_ctx.execute(cmdline) + if result.stderr or not (empty_stdout_fine or result.stdout): + _fail("\n".join([ + error_msg.strip() if error_msg else "Repository command failed", + result.stderr.strip(), + error_details if error_details else "", + ])) + return result + +def _read_dir(repository_ctx, src_dir): + """Returns a string with all files in a directory. + + Finds all files inside a directory, traversing subfolders and following + symlinks. The returned string contains the full path of all files + separated by line breaks. + + Args: + repository_ctx: the repository_ctx object. + src_dir: directory to find files from. + + Returns: + A string of all files inside the given dir. + """ + if _is_windows(repository_ctx): + src_dir = src_dir.replace("/", "\\") + find_result = _execute( + repository_ctx, + ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"], + empty_stdout_fine = True, + ) + + # src_files will be used in genrule.outs where the paths must + # use forward slashes. + result = find_result.stdout.replace("\\", "/") + else: + find_result = _execute( + repository_ctx, + ["find", src_dir, "-follow", "-type", "f"], + empty_stdout_fine = True, + ) + result = find_result.stdout + return result + +def _genrule(genrule_name, command, outs): + """Returns a string with a genrule. + + Genrule executes the given command and produces the given outputs. + + Args: + genrule_name: A unique name for genrule target. + command: The command to run. + outs: A list of files generated by this rule. + + Returns: + A genrule target. 
+ """ + return ( + "genrule(\n" + + ' name = "' + + genrule_name + '",\n' + + " outs = [\n" + + outs + + "\n ],\n" + + ' cmd = """\n' + + command + + '\n """,\n' + + ")\n" + ) + +def _norm_path(path): + """Returns a path with '/' and remove the trailing slash.""" + path = path.replace("\\", "/") + if path[-1] == "/": + path = path[:-1] + return path + +def _symlink_genrule_for_dir( + repository_ctx, + src_dir, + dest_dir, + genrule_name, + src_files = [], + dest_files = [], + tf_pip_dir_rename_pair = []): + """Returns a genrule to symlink(or copy if on Windows) a set of files. + + If src_dir is passed, files will be read from the given directory; otherwise + we assume files are in src_files and dest_files. + + Args: + repository_ctx: the repository_ctx object. + src_dir: source directory. + dest_dir: directory to create symlink in. + genrule_name: genrule name. + src_files: list of source files instead of src_dir. + dest_files: list of corresonding destination files. + tf_pip_dir_rename_pair: list of the pair of tf pip parent directory to + replace. For example, in TF pip package, the source code is under + "tensorflow_core", and we might want to replace it with + "tensorflow" to match the header includes. + Returns: + genrule target that creates the symlinks. + """ + + # Check that tf_pip_dir_rename_pair has the right length + tf_pip_dir_rename_pair_len = len(tf_pip_dir_rename_pair) + if tf_pip_dir_rename_pair_len != 0 and tf_pip_dir_rename_pair_len != 2: + _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % tf_pip_dir_rename_pair_len) + + if src_dir != None: + src_dir = _norm_path(src_dir) + dest_dir = _norm_path(dest_dir) + files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines())) + + # Create a list with the src_dir stripped to use for outputs. 
+ if tf_pip_dir_rename_pair_len: + dest_files = files.replace(src_dir, "").replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1]).splitlines() + else: + dest_files = files.replace(src_dir, "").splitlines() + src_files = files.splitlines() + command = [] + outs = [] + for i in range(len(dest_files)): + if dest_files[i] != "": + # If we have only one file to link we do not want to use the dest_dir, as + # $(@D) will include the full path to the file. + dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i] + + # Copy the headers to create a sandboxable setup. + cmd = "cp -f" + command.append(cmd + ' "%s" "%s"' % (src_files[i], dest)) + outs.append(' "' + dest_dir + dest_files[i] + '",') + dest_dir = "abc" + genrule = _genrule( + genrule_name, + " && ".join(command), + "\n".join(outs), + ) + return genrule + +def _tf_pip_impl(repository_ctx): + tf_header_dir = repository_ctx.os.environ[_TF_HEADER_DIR] + tf_header_rule = _symlink_genrule_for_dir( + repository_ctx, + tf_header_dir, + "include", + "tf_header_include", + tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"], + ) + + tf_shared_library_dir = repository_ctx.os.environ[_TF_SHARED_LIBRARY_DIR] + tf_shared_library_name = repository_ctx.os.environ[_TF_SHARED_LIBRARY_NAME] + tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name) + tf_shared_library_rule = _symlink_genrule_for_dir( + repository_ctx, + None, + "", + "libtensorflow_framework.so", + [tf_shared_library_path], + ["_pywrap_tensorflow_internal.lib" if _is_windows(repository_ctx) else "libtensorflow_framework.so"], + ) + + _tpl(repository_ctx, "BUILD", { + "%{TF_HEADER_GENRULE}": tf_header_rule, + "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule, + }) + +tf_configure = repository_rule( + implementation = _tf_pip_impl, + environ = [ + _TF_HEADER_DIR, + _TF_SHARED_LIBRARY_DIR, + _TF_SHARED_LIBRARY_NAME, + ], +) diff --git a/third_party/tensorflow_lite_ios_build.patch 
b/third_party/tensorflow_lite_ios_build.patch new file mode 100644 index 00000000..786e46bc --- /dev/null +++ b/third_party/tensorflow_lite_ios_build.patch @@ -0,0 +1,40 @@ +diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD +similarity index 97% +rename from tensorflow/lite/experimental/ios/BUILD.apple +rename to tensorflow/lite/experimental/ios/BUILD +index cce0c4df..49eba35f 100644 +--- a/tensorflow/lite/experimental/ios/BUILD.apple ++++ b/tensorflow/lite/experimental/ios/BUILD +@@ -22,8 +22,7 @@ sh_binary( + "hide_symbols_with_allowlist.sh", + ], + visibility = [ +- "//tensorflow/lite:__subpackages__", +- "//tensorflow_lite_support:__subpackages__", ++ "//visibility:public", + ], + ) + +diff --git a/tensorflow/lite/experimental/ios/ios.bzl b/tensorflow/lite/experimental/ios/ios.bzl +index 63747eb8..07bcb49d 100644 +--- a/tensorflow/lite/experimental/ios/ios.bzl ++++ b/tensorflow/lite/experimental/ios/ios.bzl +@@ -60,7 +60,7 @@ def tflite_ios_static_framework( + "BUNDLE_NAME=\"" + bundle_name + "\" " + + "ALLOWLIST_FILE_PATH=\"$(location " + allowlist_symbols_file + ")\" " + + "OUTPUT=\"$(OUTS)\" " + +- "\"$(location //tensorflow/lite/experimental/ios:hide_symbols_with_allowlist)\"") ++ "\"$(location @org_tensorflow//tensorflow/lite/experimental/ios:hide_symbols_with_allowlist)\"") + + native.genrule( + name = name, +@@ -68,7 +68,7 @@ def tflite_ios_static_framework( + outs = [name + ".zip"], + cmd = cmd, + tools = [ +- "//tensorflow/lite/experimental/ios:hide_symbols_with_allowlist", ++ "@org_tensorflow//tensorflow/lite/experimental/ios:hide_symbols_with_allowlist", + ], + ) + + diff --git a/third_party/tensorflow_text_remove_tf_deps.patch b/third_party/tensorflow_text_remove_tf_deps.patch new file mode 100644 index 00000000..f7b86f9f --- /dev/null +++ b/third_party/tensorflow_text_remove_tf_deps.patch @@ -0,0 +1,32 @@ +diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD +index 
bdca365..1c20eae 100644 +--- a/tensorflow_text/core/kernels/BUILD ++++ b/tensorflow_text/core/kernels/BUILD +@@ -209,8 +209,12 @@ cc_library( + name = "regex_split", + srcs = ["regex_split.cc"], + hdrs = ["regex_split.h"], +- deps = OSS_DEPS + [ ++ deps = [ + # absl/strings dep ++ "@com_google_absl//absl/container:inlined_vector", ++ "@com_google_absl//absl/strings", ++ "@com_google_absl//absl/types:optional", ++ "@com_google_absl//absl/types:span", + "@com_google_re2//:re2", + ], + ) +@@ -437,8 +441,12 @@ cc_library( + name = "wordpiece_tokenizer", + srcs = ["wordpiece_tokenizer.cc"], + hdrs = ["wordpiece_tokenizer.h"], +- deps = OSS_DEPS + [ ++ deps = [ + # absl/strings dep ++ "@com_google_absl//absl/container:inlined_vector", ++ "@com_google_absl//absl/strings", ++ "@com_google_absl//absl/types:optional", ++ "@com_google_absl//absl/types:span", + "@icu//:common", + ], + )
\ No newline at end of file diff --git a/third_party/toolchains/java/BUILD b/third_party/toolchains/java/BUILD new file mode 100644 index 00000000..83722915 --- /dev/null +++ b/third_party/toolchains/java/BUILD @@ -0,0 +1,18 @@ +# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1 +# TensorFlow still targets Java 1.7 (See JAVACOPTS in tensorflow/java/build_defs.bzl) +# which doesn't support "-parameters" flag. Starting from Java 11 (default since Bazel +# 0.29.1), a warning message will be thrown if "-parameters" is passed. If "-Werror" also exists, +# the compiling action will fail. To workaround this, we override the misc value of +# the default java toolchain to remove "-parameters" flag. +load("@bazel_tools//tools/jdk:default_java_toolchain.bzl", "default_java_toolchain") + +licenses(["notice"]) + +default_java_toolchain( + name = "tf_java_toolchain", + misc = [ + "-XDskipDuplicateBridges=true", + "-g", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/utf.BUILD b/third_party/utf.BUILD new file mode 100644 index 00000000..0a21fd78 --- /dev/null +++ b/third_party/utf.BUILD @@ -0,0 +1,38 @@ +cc_library( + name = "utf", + srcs = [ + "libutf/rune.c", + "libutf/runestrcat.c", + "libutf/runestrchr.c", + "libutf/runestrcmp.c", + "libutf/runestrcpy.c", + "libutf/runestrdup.c", + "libutf/runestrecpy.c", + "libutf/runestrlen.c", + "libutf/runestrncat.c", + "libutf/runestrncmp.c", + "libutf/runestrncpy.c", + "libutf/runestrrchr.c", + "libutf/runestrstr.c", + "libutf/runetype.c", + "libutf/utfecpy.c", + "libutf/utflen.c", + "libutf/utfnlen.c", + "libutf/utfrrune.c", + "libutf/utfrune.c", + "libutf/utfutf.c", + ], + hdrs = [ + "libutf/plan9.h", + "libutf/utf.h", + "libutf/utfdef.h", + ], + copts = [ + "-Wno-parentheses", + ], + includes = [ + ".", + "libutf", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD new file mode 100644 index 00000000..275782e0 --- 
/dev/null +++ b/third_party/zlib.BUILD @@ -0,0 +1,39 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "zlib", + srcs = [ + "adler32.c", + "compress.c", + "crc32.c", + "crc32.h", + "deflate.c", + "deflate.h", + "gzclose.c", + "gzguts.h", + "gzlib.c", + "gzread.c", + "gzwrite.c", + "infback.c", + "inffast.c", + "inffast.h", + "inffixed.h", + "inflate.c", + "inflate.h", + "inftrees.c", + "inftrees.h", + "trees.c", + "trees.h", + "uncompr.c", + "zutil.c", + "zutil.h", + ], + hdrs = [ + "zconf.h", + "zlib.h", + ], + copts = ["-Wno-implicit-function-declaration"], + includes = ["."], +) |