aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc')
-rw-r--r--tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc46
1 files changed, 37 insertions, 9 deletions
diff --git a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
index 31d693a8..32d1054d 100644
--- a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
+++ b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include <memory>
-#include "absl/memory/memory.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/op_resolver.h"
@@ -23,21 +22,15 @@ namespace tflite {
namespace task {
// Create a minimal MutableOpResolver to provide only
-// the ops required by NLClassifier/BertNLClassifier.
+// the ops required by the bert_nl_classifier and rb_model for BertNLClassifier.
std::unique_ptr<MutableOpResolver> CreateOpResolver() {
MutableOpResolver resolver;
- resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE,
- ::tflite::ops::builtin::Register_DEQUANTIZE());
resolver.AddBuiltin(::tflite::BuiltinOperator_RESHAPE,
::tflite::ops::builtin::Register_RESHAPE());
resolver.AddBuiltin(::tflite::BuiltinOperator_GATHER,
::tflite::ops::builtin::Register_GATHER());
resolver.AddBuiltin(::tflite::BuiltinOperator_STRIDED_SLICE,
::tflite::ops::builtin::Register_STRIDED_SLICE());
- resolver.AddBuiltin(::tflite::BuiltinOperator_PAD,
- ::tflite::ops::builtin::Register_PAD());
- resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
- ::tflite::ops::builtin::Register_CONCATENATION());
resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
::tflite::ops::builtin::Register_FULLY_CONNECTED());
resolver.AddBuiltin(::tflite::BuiltinOperator_CAST,
@@ -54,7 +47,42 @@ std::unique_ptr<MutableOpResolver> CreateOpResolver() {
::tflite::ops::builtin::Register_PACK());
resolver.AddBuiltin(::tflite::BuiltinOperator_SOFTMAX,
::tflite::ops::builtin::Register_SOFTMAX());
- return absl::make_unique<MutableOpResolver>(resolver);
+ resolver.AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS,
+ ::tflite::ops::builtin::Register_EXPAND_DIMS());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SHAPE,
+ ::tflite::ops::builtin::Register_SHAPE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_FILL,
+ ::tflite::ops::builtin::Register_FILL());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SUB,
+ ::tflite::ops::builtin::Register_SUB());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_MEAN,
+ ::tflite::ops::builtin::Register_MEAN());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
+ ::tflite::ops::builtin::Register_SQUARED_DIFFERENCE());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_RSQRT,
+ ::tflite::ops::builtin::Register_RSQRT());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_BATCH_MATMUL,
+ ::tflite::ops::builtin::Register_BATCH_MATMUL());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_GELU,
+ ::tflite::ops::builtin::Register_GELU());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_TANH,
+ ::tflite::ops::builtin::Register_TANH());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_LOGISTIC,
+ ::tflite::ops::builtin::Register_LOGISTIC());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_SLICE,
+ ::tflite::ops::builtin::Register_SLICE());
+ // Needed for the test bert_nl_classifier model.
+ resolver.AddBuiltin(::tflite::BuiltinOperator_PAD,
+ ::tflite::ops::builtin::Register_PAD());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
+ ::tflite::ops::builtin::Register_CONCATENATION());
+ resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
+ ::tflite::ops::builtin::Register_FULLY_CONNECTED(),
+ /*version=*/9);
+ resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE,
+ ::tflite::ops::builtin::Register_DEQUANTIZE(),
+ /*version=*/2);
+ return std::make_unique<MutableOpResolver>(resolver);
}
} // namespace task