diff options
author | Ben Nissan <bennissan@google.com> | 2022-07-13 19:36:16 +0000 |
---|---|---|
committer | Ben Nissan <bennissan@google.com> | 2022-07-14 19:45:03 +0000 |
commit | 5d1e591a33054b75a5214c75be68cc14877b31d2 (patch) | |
tree | fb25d55229be4360d8611a7fd6bce6b4baf57d1f | |
parent | 456bf98dcc1a98a825c0d34cb374ba8bbf73083a (diff) | |
download | tflite-support-5d1e591a33054b75a5214c75be68cc14877b31d2.tar.gz |
Use custom op resolver for NL classifiers
This CL implements a custom, minimal op resolver containing only the
TensorFlow ops required by NLClassifier and BertNLClassifier, and
replaces the BuiltInOpResolver currently used in these classes with
the new MinimalOpResolver.
Bug: 238435760
Test: atest TfliteSupportClassifierTests
[From AdServices] atest OnDeviceClassifierTest
Change-Id: Ifef01e79228ad369fb16a8f5e3a1b154e3f7998f
-rw-r--r-- | Android.bp | 2 | ||||
-rw-r--r-- | tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc | 61 |
2 files changed, 62 insertions, 1 deletions
@@ -218,7 +218,7 @@ cc_library_shared { "tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc", "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc", "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc", - "tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc", + "tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc", "tensorflow_lite_support/cc/utils/jni_utils.cc", ], shared_libs: ["liblog"], diff --git a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc new file mode 100644 index 00000000..31d693a8 --- /dev/null +++ b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc @@ -0,0 +1,61 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> + +#include "absl/memory/memory.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/op_resolver.h" + +namespace tflite { +namespace task { + +// Create a minimal MutableOpResolver to provide only +// the ops required by NLClassifier/BertNLClassifier. +std::unique_ptr<MutableOpResolver> CreateOpResolver() { + MutableOpResolver resolver; + resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE, + ::tflite::ops::builtin::Register_DEQUANTIZE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_RESHAPE, + ::tflite::ops::builtin::Register_RESHAPE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_GATHER, + ::tflite::ops::builtin::Register_GATHER()); + resolver.AddBuiltin(::tflite::BuiltinOperator_STRIDED_SLICE, + ::tflite::ops::builtin::Register_STRIDED_SLICE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_PAD, + ::tflite::ops::builtin::Register_PAD()); + resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION, + ::tflite::ops::builtin::Register_CONCATENATION()); + resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED, + ::tflite::ops::builtin::Register_FULLY_CONNECTED()); + resolver.AddBuiltin(::tflite::BuiltinOperator_CAST, + ::tflite::ops::builtin::Register_CAST()); + resolver.AddBuiltin(::tflite::BuiltinOperator_MUL, + ::tflite::ops::builtin::Register_MUL()); + resolver.AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + resolver.AddBuiltin(::tflite::BuiltinOperator_TRANSPOSE, + ::tflite::ops::builtin::Register_TRANSPOSE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_SPLIT, + ::tflite::ops::builtin::Register_SPLIT()); + resolver.AddBuiltin(::tflite::BuiltinOperator_PACK, + ::tflite::ops::builtin::Register_PACK()); + resolver.AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, + ::tflite::ops::builtin::Register_SOFTMAX()); + return absl::make_unique<MutableOpResolver>(resolver); +} + +} // namespace task +} // namespace tflite
\ No newline at end of file |