diff options
author | fban <fban@google.com> | 2022-11-16 03:38:49 +0000 |
---|---|---|
committer | fban <fban@google.com> | 2022-11-16 03:38:49 +0000 |
commit | df43288acfb34d7029bd50ad3869c9d50c29afae (patch) | |
tree | 34aa1d83e6ccc821a23dda8063487f59384211e6 | |
parent | 0320ce8084721a759e7d2e6965d17540eeec8c36 (diff) | |
download | tflite-support-df43288acfb34d7029bd50ad3869c9d50c29afae.tar.gz |
Adds missing OpResolver calls to the BertNLClassifier JNI layer.
This is a necessary prerequisite in order to support a custom
op-resolver for Rubidium.
Test: atest OnDeviceClassifierTest
Bug: 238435760
Change-Id: I932153486e7bf561cd674342b868080c11b292e6
2 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc index 88a8a0af..4e4e999d 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc @@ -64,7 +64,6 @@ constexpr char kSeparator[] = "[SEP]"; constexpr int kTokenizerProcessUnitIndex = 0; } // namespace -// TODO(b/241507692) Add a unit test for a model with dynamic tensors. absl::Status BertNLClassifier::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) { auto* input_tensor_metadatas = diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc index aef82408..7781a4a8 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc @@ -14,11 +14,18 @@ limitations under the License. ==============================================================================*/ #include <jni.h> - +#include "tensorflow/lite/op_resolver.h" #include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h" #include "tensorflow_lite_support/cc/utils/jni_utils.h" #include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h" +namespace tflite { +namespace task { +// To be provided by a link-time library +extern std::unique_ptr<OpResolver> CreateOpResolver(); +} // namespace task +} // namespace tflite + namespace { using ::tflite::support::utils::GetMappedFileBuffer; @@ -41,7 +48,8 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByte JNIEnv* env, jclass thiz, jobject model_buffer) { auto model = GetMappedFileBuffer(env, model_buffer); tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status = - BertNLClassifier::CreateFromBuffer(model.data(), model.size()); + BertNLClassifier::CreateFromBuffer(model.data(), model.size(), + tflite::task::CreateOpResolver()); if (status.ok()) { return reinterpret_cast<jlong>(status->release()); } else { @@ -56,7 +64,7 @@ extern "C" JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor( JNIEnv* env, jclass thiz, jint fd) { tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status = - BertNLClassifier::CreateFromFd(fd); + BertNLClassifier::CreateFromFd(fd, tflite::task::CreateOpResolver()); if (status.ok()) { return reinterpret_cast<jlong>(status->release()); } else { |