aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfban <fban@google.com>2022-11-16 03:38:49 +0000
committerfban <fban@google.com>2022-11-16 03:38:49 +0000
commitdf43288acfb34d7029bd50ad3869c9d50c29afae (patch)
tree34aa1d83e6ccc821a23dda8063487f59384211e6
parent0320ce8084721a759e7d2e6965d17540eeec8c36 (diff)
downloadtflite-support-df43288acfb34d7029bd50ad3869c9d50c29afae.tar.gz
Adds missing OpResolver calls to the BertNLClassifier JNI layer.
This is a necessary prerequisite in order to support a custom op-resolver for Rubidium. Test: atest OnDeviceClassifierTest Bug: 238435760 Change-Id: I932153486e7bf561cd674342b868080c11b292e6
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc1
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc14
2 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
index 88a8a0af..4e4e999d 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
@@ -64,7 +64,6 @@ constexpr char kSeparator[] = "[SEP]";
constexpr int kTokenizerProcessUnitIndex = 0;
} // namespace
-// TODO(b/241507692) Add a unit test for a model with dynamic tensors.
absl::Status BertNLClassifier::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
auto* input_tensor_metadatas =
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
index aef82408..7781a4a8 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
@@ -14,11 +14,18 @@ limitations under the License.
==============================================================================*/
#include <jni.h>
-
+#include "tensorflow/lite/op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
#include "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h"
+namespace tflite {
+namespace task {
+// To be provided by a link-time library
+extern std::unique_ptr<OpResolver> CreateOpResolver();
+} // namespace task
+} // namespace tflite
+
namespace {
using ::tflite::support::utils::GetMappedFileBuffer;
@@ -41,7 +48,8 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByte
JNIEnv* env, jclass thiz, jobject model_buffer) {
auto model = GetMappedFileBuffer(env, model_buffer);
tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status =
- BertNLClassifier::CreateFromBuffer(model.data(), model.size());
+ BertNLClassifier::CreateFromBuffer(model.data(), model.size(),
+ tflite::task::CreateOpResolver());
if (status.ok()) {
return reinterpret_cast<jlong>(status->release());
} else {
@@ -56,7 +64,7 @@ extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor(
JNIEnv* env, jclass thiz, jint fd) {
tflite::support::StatusOr<std::unique_ptr<BertNLClassifier>> status =
- BertNLClassifier::CreateFromFd(fd);
+ BertNLClassifier::CreateFromFd(fd, tflite::task::CreateOpResolver());
if (status.ok()) {
return reinterpret_cast<jlong>(status->release());
} else {