diff options
Diffstat (limited to 'tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc')
-rw-r--r-- | tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc | 46 |
1 files changed, 37 insertions, 9 deletions
diff --git a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc index 31d693a8..32d1054d 100644 --- a/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc +++ b/tensorflow_lite_support/java/src/native/task/core/minimal_op_resolver.cc @@ -15,7 +15,6 @@ limitations under the License. #include <memory> -#include "absl/memory/memory.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/op_resolver.h" @@ -23,21 +22,15 @@ namespace tflite { namespace task { // Create a minimal MutableOpResolver to provide only -// the ops required by NLClassifier/BertNLClassifier. +// the ops required by the bert_nl_classifier and rb_model for BertNLClassifier. std::unique_ptr<MutableOpResolver> CreateOpResolver() { MutableOpResolver resolver; - resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE, - ::tflite::ops::builtin::Register_DEQUANTIZE()); resolver.AddBuiltin(::tflite::BuiltinOperator_RESHAPE, ::tflite::ops::builtin::Register_RESHAPE()); resolver.AddBuiltin(::tflite::BuiltinOperator_GATHER, ::tflite::ops::builtin::Register_GATHER()); resolver.AddBuiltin(::tflite::BuiltinOperator_STRIDED_SLICE, ::tflite::ops::builtin::Register_STRIDED_SLICE()); - resolver.AddBuiltin(::tflite::BuiltinOperator_PAD, - ::tflite::ops::builtin::Register_PAD()); - resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION, - ::tflite::ops::builtin::Register_CONCATENATION()); resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED, ::tflite::ops::builtin::Register_FULLY_CONNECTED()); resolver.AddBuiltin(::tflite::BuiltinOperator_CAST, @@ -54,7 +47,42 @@ std::unique_ptr<MutableOpResolver> CreateOpResolver() { ::tflite::ops::builtin::Register_PACK()); resolver.AddBuiltin(::tflite::BuiltinOperator_SOFTMAX, ::tflite::ops::builtin::Register_SOFTMAX()); - return absl::make_unique<MutableOpResolver>(resolver); + resolver.AddBuiltin(::tflite::BuiltinOperator_EXPAND_DIMS, + ::tflite::ops::builtin::Register_EXPAND_DIMS()); + resolver.AddBuiltin(::tflite::BuiltinOperator_SHAPE, + ::tflite::ops::builtin::Register_SHAPE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_FILL, + ::tflite::ops::builtin::Register_FILL()); + resolver.AddBuiltin(::tflite::BuiltinOperator_SUB, + ::tflite::ops::builtin::Register_SUB()); + resolver.AddBuiltin(::tflite::BuiltinOperator_MEAN, + ::tflite::ops::builtin::Register_MEAN()); + resolver.AddBuiltin(::tflite::BuiltinOperator_SQUARED_DIFFERENCE, + ::tflite::ops::builtin::Register_SQUARED_DIFFERENCE()); + resolver.AddBuiltin(::tflite::BuiltinOperator_RSQRT, + ::tflite::ops::builtin::Register_RSQRT()); + resolver.AddBuiltin(::tflite::BuiltinOperator_BATCH_MATMUL, + ::tflite::ops::builtin::Register_BATCH_MATMUL()); + resolver.AddBuiltin(::tflite::BuiltinOperator_GELU, + ::tflite::ops::builtin::Register_GELU()); + resolver.AddBuiltin(::tflite::BuiltinOperator_TANH, + ::tflite::ops::builtin::Register_TANH()); + resolver.AddBuiltin(::tflite::BuiltinOperator_LOGISTIC, + ::tflite::ops::builtin::Register_LOGISTIC()); + resolver.AddBuiltin(::tflite::BuiltinOperator_SLICE, + ::tflite::ops::builtin::Register_SLICE()); + // Needed for the test bert_nl_classifier model. + resolver.AddBuiltin(::tflite::BuiltinOperator_PAD, + ::tflite::ops::builtin::Register_PAD()); + resolver.AddBuiltin(::tflite::BuiltinOperator_CONCATENATION, + ::tflite::ops::builtin::Register_CONCATENATION()); + resolver.AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED, + ::tflite::ops::builtin::Register_FULLY_CONNECTED(), + /*version=*/9); + resolver.AddBuiltin(::tflite::BuiltinOperator_DEQUANTIZE, + ::tflite::ops::builtin::Register_DEQUANTIZE(), + /*version=*/2); + return std::make_unique<MutableOpResolver>(resolver); } } // namespace task |