aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-06-14 13:52:51 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-06-14 13:52:51 +0000
commitb6a3f195652064d8f4f776f34e82811de7790a8a (patch)
tree7477ac98623e59634aaacf509972e7645e0a7b7e
parent7868065c1247bbd145c16f07ef80b24ae9685816 (diff)
parent6398339e056464f627207764f18c41a16043ccea (diff)
downloadtflite-support-b6a3f195652064d8f4f776f34e82811de7790a8a.tar.gz
Snap for 8720775 from 6398339e056464f627207764f18c41a16043ccea to mainline-sdkext-release
Change-Id: Ib922de86cbd5383e352a0842f89212acf475dc8f
-rw-r--r--Android.bp53
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc4
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h3
-rw-r--r--tensorflow_lite_support/java/AndroidManifest.xml4
-rw-r--r--tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java2
-rw-r--r--tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java13
-rw-r--r--tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java2
-rw-r--r--tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java2
-rw-r--r--tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java50
-rw-r--r--tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java88
-rw-r--r--tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflitebin0 -> 25707538 bytes
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc7
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc5
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h2
-rw-r--r--tensorflow_lite_support/metadata/cc/metadata_extractor.cc9
-rw-r--r--tensorflow_lite_support/metadata/cc/metadata_extractor.h5
16 files changed, 244 insertions, 5 deletions
diff --git a/Android.bp b/Android.bp
index 41b981da..a0fb7f15 100644
--- a/Android.bp
+++ b/Android.bp
@@ -213,7 +213,7 @@ java_library {
}
cc_library_shared {
- name: "tflite_support_classifiers_native",
+ name: "libtflite_support_classifiers_native",
srcs: [
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc",
"tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc",
@@ -373,6 +373,57 @@ cc_test {
],
}
+android_test {
+ name: "TfliteSupportClassifierTests",
+ srcs: [
+ "tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java",
+ ],
+ asset_dirs: [
+ "tensorflow_lite_support/java/src/javatests/testdata/task/text",
+ ],
+ defaults: ["modules-utils-testable-device-config-defaults"],
+ manifest: "tensorflow_lite_support/java/AndroidManifest.xml",
+ sdk_version: "module_current",
+ min_sdk_version: "30",
+ static_libs: [
+ "androidx.test.core",
+ "tensorflowlite_java",
+ "truth-prebuilt",
+ "tflite_support_classifiers_java",
+ "tflite_support_test_utils_java",
+ ],
+ libs: [
+ "android.test.base",
+ "android.test.mock.stubs",
+ ],
+ test_suites: [
+ "general-tests",
+ ],
+ jni_libs: [
+ "libtflite_support_classifiers_native",
+ ],
+ aaptflags: [
+ // Avoid compression on tflite files as the Interpreter
+ // can not load compressed flat buffer formats.
+ // (*appt compresses all assets into the apk by default)
+ // See https://elinux.org/Android_aapt for more detail.
+ "-0 .tflite",
+ ],
+}
+
+java_library_static {
+ name: "tflite_support_test_utils_java",
+ sdk_version: "module_current",
+ min_sdk_version: "30",
+ srcs: [
+ "tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java",
+ ],
+ static_libs: [
+ "apache-commons-compress",
+ "guava",
+ ],
+}
+
cc_library_static {
name: "tflite_support_task_core_proto",
proto: {
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
index a1d4196e..5fafd17c 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
@@ -184,6 +184,10 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) {
return Infer(text).value();
}
+std::string NLClassifier::GetVersion() const {
+ return GetMetadataExtractor()->GetVersion();
+}
+
absl::Status NLClassifier::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
index 2a9573a1..05189dda 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
@@ -112,6 +112,9 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
// Performs classification on a string input, returns classified results.
std::vector<core::Category> Classify(const std::string& text);
+ // Gets the model version, or "NO_VERSION_INFO" in case there is no version.
+ std::string GetVersion() const;
+
protected:
static constexpr int kOutputTensorIndex = 0;
static constexpr int kOutputTensorLabelFileIndex = 0;
diff --git a/tensorflow_lite_support/java/AndroidManifest.xml b/tensorflow_lite_support/java/AndroidManifest.xml
index 14909296..c36eb383 100644
--- a/tensorflow_lite_support/java/AndroidManifest.xml
+++ b/tensorflow_lite_support/java/AndroidManifest.xml
@@ -2,4 +2,8 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.tensorflow.lite.support">
<uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/>
+ <instrumentation
+ android:name="androidx.test.runner.AndroidJUnitRunner"
+ android:targetPackage="org.tensorflow.lite.support" >
+ </instrumentation>
</manifest>
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
index 5b043a9f..cce42399 100644
--- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java
@@ -40,8 +40,8 @@ public final class Category {
return new Category(label, displayName, score);
}
- @UsedByReflection("TFLiteSupport/Task")
/** Constructs a {@link Category} object with an empty displayName. */
+ @UsedByReflection("TFLiteSupport/Task")
public Category(String label, float score) {
this(label, /*displayName=*/ "", score);
}
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
index 90bea370..0906cd03 100644
--- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
@@ -42,7 +42,7 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
* </ul>
*/
public class BertNLClassifier extends BaseTaskApi {
- private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
+ private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "tflite_support_classifiers_native";
/**
* Constructor to initialize the JNI with a pointer from C++.
@@ -122,12 +122,23 @@ public class BertNLClassifier extends BaseTaskApi {
return classifyNative(getNativeHandle(), text);
}
+ /**
+ * Gets the model version from the model metadata.
+ *
+ * @return The model version.
+ */
+ public String getVersion() {
+ return getVersionNative(getNativeHandle());
+ }
+
private static native long initJniWithByteBuffer(ByteBuffer modelBuffer);
private static native long initJniWithFileDescriptor(int fd);
private static native List<Category> classifyNative(long nativeHandle, String text);
+ private static native String getVersionNative(long nativeHandle);
+
@Override
protected void deinit(long nativeHandle) {
deinitJni(nativeHandle);
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
index 2bc20d8c..642fab5b 100644
--- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java
@@ -126,7 +126,7 @@ public class NLClassifier extends BaseTaskApi {
}
}
- private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni";
+ private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "tflite_support_classifiers_native";
/**
* Constructor to initialize the JNI with a pointer from C++.
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
index 76f562ef..7c019d21 100644
--- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
@@ -28,7 +28,7 @@ import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
/** Task API for BertQA models. */
public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer {
- private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
+ private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "tflite_support_classifiers_native";
private BertQuestionAnswerer(long nativeHandle) {
super(nativeHandle);
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java
new file mode 100644
index 00000000..4c6f369d
--- /dev/null
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java
@@ -0,0 +1,50 @@
+package org.tensorflow.lite.task.core;
+
+import android.content.Context;
+import android.content.res.AssetManager;
+
+import com.google.common.io.ByteStreams;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+/** Helper class for the Java test in Task Libary. */
+public final class TestUtils {
+
+ /**
+ * Loads the file and create a {@link File} object by reading a file from the asset directory.
+ * Simulates downloading or reading a file that's not precompiled with the app.
+ *
+ * @return a {@link File} object for the model.
+ */
+ public static File loadFile(Context context, String fileName) {
+ File target = new File(context.getFilesDir(), fileName);
+ try (InputStream is = context.getAssets().open(fileName);
+ FileOutputStream os = new FileOutputStream(target)) {
+ ByteStreams.copy(is, os);
+ } catch (IOException e) {
+ throw new AssertionError("Failed to load model file at " + fileName, e);
+ }
+ return target;
+ }
+
+ /**
+ * Reads a file into a direct {@link ByteBuffer} object from the asset directory.
+ *
+ * @return a {@link ByteBuffer} object for the file.
+ */
+ public static ByteBuffer loadToDirectByteBuffer(Context context, String fileName)
+ throws IOException {
+ AssetManager assetManager = context.getAssets();
+ InputStream inputStream = assetManager.open(fileName);
+ byte[] bytes = ByteStreams.toByteArray(inputStream);
+
+ ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder());
+ buffer.put(bytes);
+ return buffer;
+ }
+} \ No newline at end of file
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
new file mode 100644
index 00000000..a4c2f525
--- /dev/null
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
@@ -0,0 +1,88 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.text.nlclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import androidx.test.core.app.ApplicationProvider;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.junit.Test;
+import org.tensorflow.lite.support.label.Category;
+import org.tensorflow.lite.task.core.TestUtils;
+
+/** Test for {@link BertNLClassifier}. */
+public class BertNLClassifierTest {
+ private static final String MODEL_FILE = "bert_nl_classifier.tflite";
+
+ Category findCategoryWithLabel(List<Category> list, String label) {
+ return list.stream()
+ .filter(category -> label.equals(category.getLabel()))
+ .findAny()
+ .orElse(null);
+ }
+
+ @Test
+ public void createFromPath_verifyResults() throws IOException {
+ verifyResults(
+ BertNLClassifier.createFromFile(ApplicationProvider.getApplicationContext(),
+ MODEL_FILE));
+ }
+
+ @Test
+ public void createFromFile_verifyResults() throws IOException {
+ verifyResults(
+ BertNLClassifier.createFromFile(
+ TestUtils.loadFile(ApplicationProvider.getApplicationContext(),
+ MODEL_FILE)));
+ }
+
+ @Test
+ public void classify_succeedsWithModelFile() throws IOException {
+ verifyResults(
+ BertNLClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), MODEL_FILE));
+ }
+
+ @Test
+ public void classify_succeedsWithModelBuffer() throws IOException {
+ verifyResults(
+ BertNLClassifier.createFromBuffer(
+ TestUtils.loadToDirectByteBuffer(
+ ApplicationProvider.getApplicationContext(), MODEL_FILE)));
+ }
+
+ @Test
+ public void getVersion_succeedsWithVersionInMetadata() throws IOException {
+ BertNLClassifier classifier = BertNLClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), MODEL_FILE);
+
+ assertThat(classifier.getVersion()).isEqualTo("v1");
+ }
+
+ private void verifyResults(BertNLClassifier classifier) {
+ List<Category> negativeResults = classifier.classify("unflinchingly bleak and desperate");
+ assertThat(findCategoryWithLabel(negativeResults, "negative").getScore())
+ .isGreaterThan(findCategoryWithLabel(negativeResults, "positive").getScore());
+
+ List<Category> positiveResults =
+ classifier.classify("it's a charming and often affecting journey");
+ assertThat(findCategoryWithLabel(positiveResults, "positive").getScore())
+ .isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore());
+ }
+} \ No newline at end of file
diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite
new file mode 100644
index 00000000..97a32da4
--- /dev/null
+++ b/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite
Binary files differ
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
index 1edb3507..0866764e 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
@@ -27,6 +27,7 @@ using ::tflite::support::utils::kInvalidPointer;
using ::tflite::support::utils::ThrowException;
using ::tflite::task::text::nlclassifier::BertNLClassifier;
using ::tflite::task::text::nlclassifier::RunClassifier;
+using ::tflite::task::text::nlclassifier::GetVersionNative;
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni(
@@ -71,4 +72,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative(
return RunClassifier(env, native_handle, text);
}
+extern "C" JNIEXPORT jstring JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_getVersionNative(
+ JNIEnv* env, jclass clazz, jlong native_handle) {
+ return GetVersionNative(env, native_handle);
+}
+
} // namespace
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
index c358bee1..e6040674 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
@@ -50,6 +50,11 @@ jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) {
});
}
+jstring GetVersionNative(JNIEnv* env, jlong native_handle) {
+ auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle);
+ return env->NewStringUTF(nl_classifier->GetVersion().c_str());
+}
+
} // namespace nlclassifier
} // namespace text
} // namespace task
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h
index 2c59ab50..2c8fbc07 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h
@@ -25,6 +25,8 @@ namespace nlclassifier {
jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text);
+jstring GetVersionNative(JNIEnv* env, jlong native_handle);
+
} // namespace nlclassifier
} // namespace text
} // namespace task
diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
index 7cf1b1ca..93263a17 100644
--- a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
+++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
@@ -35,6 +35,7 @@ namespace metadata {
namespace {
constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
+constexpr char kNoVersionInfo[] = "NO_VERSION_INFO";
using ::absl::StatusCode;
using ::flatbuffers::Offset;
@@ -380,5 +381,13 @@ int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
return output_process_units == nullptr ? 0 : output_process_units->size();
}
+std::string ModelMetadataExtractor::GetVersion() const {
+ if (model_metadata_ == nullptr ||
+ model_metadata_->version() == nullptr) {
+ return kNoVersionInfo;
+ }
+ return model_metadata_->version()->str();
+}
+
} // namespace metadata
} // namespace tflite \ No newline at end of file
diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/tensorflow_lite_support/metadata/cc/metadata_extractor.h
index 9a278970..c73b09b5 100644
--- a/tensorflow_lite_support/metadata/cc/metadata_extractor.h
+++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
#define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_
+#include <string>
+
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
@@ -127,6 +129,9 @@ class ModelMetadataExtractor {
// case there is no output process unit or the index is out of range.
const tflite::ProcessUnit* GetOutputProcessUnit(int index) const;
+ // Gets the model version, or "NO_VERSION_INFO" in case there is no version.
+ std::string GetVersion() const;
+
// Gets the count of output process units. In particular, 0 is returned when
// there is no output process units.
int GetOutputProcessUnitsCount() const;