aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Nissan <bennissan@google.com>2022-07-13 19:36:16 +0000
committerBen Nissan <bennissan@google.com>2022-07-14 19:45:03 +0000
commit5d1e591a33054b75a5214c75be68cc14877b31d2 (patch)
treefb25d55229be4360d8611a7fd6bce6b4baf57d1f
parent456bf98dcc1a98a825c0d34cb374ba8bbf73083a (diff)
downloadtflite-support-5d1e591a33054b75a5214c75be68cc14877b31d2.tar.gz
Use custom op resolver for NL classifiers
This CL implements a custom, minimal op resolver containing only the TensorFlow ops required by NLClassifier and BertNLClassifier, and replaces the BuiltInOpResolver currently used in these classes with the new MinimalOpResolver. Bug: 238435760 Test: atest TfliteSupportClassifierTests [From AdServices] atest OnDeviceClassifierTest Change-Id: Ifef01e79228ad369fb16a8f5e3a1b154e3f7998f
-rw-r--r--Android.bp2
-rw-r--r--tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc61
2 files changed, 62 insertions, 1 deletions
diff --git a/Android.bp b/Android.bp
index a0fb7f15..daed67eb 100644
--- a/Android.bp
+++ b/Android.bp
@@ -218,7 +218,7 @@ cc_library_shared {
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc",
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc",
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc",
- "tensorflow_lite_support/java/src/native/task/core/builtin_op_resolver.cc",
+ "tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc",
"tensorflow_lite_support/cc/utils/jni_utils.cc",
],
shared_libs: ["liblog"],
diff --git a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
new file mode 100644
index 00000000..31d693a8
--- /dev/null
+++ b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
@@ -0,0 +1,61 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/lite/kernels/builtin_op_kernels.h"
+#include "tensorflow/lite/op_resolver.h"
+
+namespace tflite {
+namespace task {
+
+// Create a minimal MutableOpResolver to provide only
+// the ops required by NLClassifier/BertNLClassifier.
+std::unique_ptr<MutableOpResolver> CreateOpResolver() {
+ MutableOpResolver resolver;
+ resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE,
+ ::tflite::ops::builtin::Register_DEQUANTIZE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_RESHAPE,
+ ::tflite::ops::builtin::Register_RESHAPE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_GATHER,
+ ::tflite::ops::builtin::Register_GATHER());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_STRIDED_SLICE,
+ ::tflite::ops::builtin::Register_STRIDED_SLICE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_PAD,
+ ::tflite::ops::builtin::Register_PAD());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
+ ::tflite::ops::builtin::Register_CONCATENATION());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
+ ::tflite::ops::builtin::Register_FULLY_CONNECTED());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_CAST,
+ ::tflite::ops::builtin::Register_CAST());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_MUL,
+ ::tflite::ops::builtin::Register_MUL());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_ADD,
+ ::tflite::ops::builtin::Register_ADD());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_TRANSPOSE,
+ ::tflite::ops::builtin::Register_TRANSPOSE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SPLIT,
+ ::tflite::ops::builtin::Register_SPLIT());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_PACK,
+ ::tflite::ops::builtin::Register_PACK());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SOFTMAX,
+ ::tflite::ops::builtin::Register_SOFTMAX());
+ return absl::make_unique<MutableOpResolver>(resolver);
+}
+
+} // namespace task
+} // namespace tflite \ No newline at end of file