diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-06-14 13:52:51 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-06-14 13:52:51 +0000 |
commit | b6a3f195652064d8f4f776f34e82811de7790a8a (patch) | |
tree | 7477ac98623e59634aaacf509972e7645e0a7b7e | |
parent | 7868065c1247bbd145c16f07ef80b24ae9685816 (diff) | |
parent | 6398339e056464f627207764f18c41a16043ccea (diff) | |
download | tflite-support-b6a3f195652064d8f4f776f34e82811de7790a8a.tar.gz |
Snap for 8720775 from 6398339e056464f627207764f18c41a16043ccea to mainline-sdkext-release
Change-Id: Ib922de86cbd5383e352a0842f89212acf475dc8f
16 files changed, 244 insertions, 5 deletions
@@ -213,7 +213,7 @@ java_library { } cc_library_shared { - name: "tflite_support_classifiers_native", + name: "libtflite_support_classifiers_native", srcs: [ "tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc", "tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni.cc", @@ -373,6 +373,57 @@ cc_test { ], } +android_test { + name: "TfliteSupportClassifierTests", + srcs: [ + "tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java", + ], + asset_dirs: [ + "tensorflow_lite_support/java/src/javatests/testdata/task/text", + ], + defaults: ["modules-utils-testable-device-config-defaults"], + manifest: "tensorflow_lite_support/java/AndroidManifest.xml", + sdk_version: "module_current", + min_sdk_version: "30", + static_libs: [ + "androidx.test.core", + "tensorflowlite_java", + "truth-prebuilt", + "tflite_support_classifiers_java", + "tflite_support_test_utils_java", + ], + libs: [ + "android.test.base", + "android.test.mock.stubs", + ], + test_suites: [ + "general-tests", + ], + jni_libs: [ + "libtflite_support_classifiers_native", + ], + aaptflags: [ + // Avoid compression on tflite files as the Interpreter + // can not load compressed flat buffer formats. + // (*appt compresses all assets into the apk by default) + // See https://elinux.org/Android_aapt for more detail. + "-0 .tflite", + ], +} + +java_library_static { + name: "tflite_support_test_utils_java", + sdk_version: "module_current", + min_sdk_version: "30", + srcs: [ + "tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java", + ], + static_libs: [ + "apache-commons-compress", + "guava", + ], +} + cc_library_static { name: "tflite_support_task_core_proto", proto: { diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc index a1d4196e..5fafd17c 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc @@ -184,6 +184,10 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) { return Infer(text).value(); } +std::string NLClassifier::GetVersion() const { + return GetMetadataExtractor()->GetVersion(); +} + absl::Status NLClassifier::Preprocess( const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) { TfLiteTensor* input_tensor = FindTensorWithNameOrIndex( diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h index 2a9573a1..05189dda 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h @@ -112,6 +112,9 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, // Performs classification on a string input, returns classified results. std::vector<core::Category> Classify(const std::string& text); + // Gets the model version, or "NO_VERSION_INFO" in case there is no version. + std::string GetVersion() const; + protected: static constexpr int kOutputTensorIndex = 0; static constexpr int kOutputTensorLabelFileIndex = 0; diff --git a/tensorflow_lite_support/java/AndroidManifest.xml b/tensorflow_lite_support/java/AndroidManifest.xml index 14909296..c36eb383 100644 --- a/tensorflow_lite_support/java/AndroidManifest.xml +++ b/tensorflow_lite_support/java/AndroidManifest.xml @@ -2,4 +2,8 @@ <manifest xmlns:android="http://schemas.android.com/apk/res/android" package="org.tensorflow.lite.support"> <uses-sdk android:minSdkVersion="19" android:targetSdkVersion="29"/> + <instrumentation + android:name="androidx.test.runner.AndroidJUnitRunner" + android:targetPackage="org.tensorflow.lite.support" > + </instrumentation> </manifest> diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java index 5b043a9f..cce42399 100644 --- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/Category.java @@ -40,8 +40,8 @@ public final class Category { return new Category(label, displayName, score); } - @UsedByReflection("TFLiteSupport/Task") /** Constructs a {@link Category} object with an empty displayName. */ + @UsedByReflection("TFLiteSupport/Task") public Category(String label, float score) { this(label, /*displayName=*/ "", score); } diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java index 90bea370..0906cd03 100644 --- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java @@ -42,7 +42,7 @@ import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider; * </ul> */ public class BertNLClassifier extends BaseTaskApi { - private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; + private static final String BERT_NL_CLASSIFIER_NATIVE_LIBNAME = "tflite_support_classifiers_native"; /** * Constructor to initialize the JNI with a pointer from C++. @@ -122,12 +122,23 @@ public class BertNLClassifier extends BaseTaskApi { return classifyNative(getNativeHandle(), text); } + /** + * Gets the model version from the model metadata. + * + * @return The model version. + */ + public String getVersion() { + return getVersionNative(getNativeHandle()); + } + private static native long initJniWithByteBuffer(ByteBuffer modelBuffer); private static native long initJniWithFileDescriptor(int fd); private static native List<Category> classifyNative(long nativeHandle, String text); + private static native String getVersionNative(long nativeHandle); + @Override protected void deinit(long nativeHandle) { deinitJni(nativeHandle); diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java index 2bc20d8c..642fab5b 100644 --- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/NLClassifier.java @@ -126,7 +126,7 @@ public class NLClassifier extends BaseTaskApi { } } - private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "task_text_jni"; + private static final String NL_CLASSIFIER_NATIVE_LIBNAME = "tflite_support_classifiers_native"; /** * Constructor to initialize the JNI with a pointer from C++. diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java index 76f562ef..7c019d21 100644 --- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java @@ -28,7 +28,7 @@ import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider; /** Task API for BertQA models. */ public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer { - private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni"; + private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "tflite_support_classifiers_native"; private BertQuestionAnswerer(long nativeHandle) { super(nativeHandle); diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java new file mode 100644 index 00000000..4c6f369d --- /dev/null +++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/core/TestUtils.java @@ -0,0 +1,50 @@ +package org.tensorflow.lite.task.core; + +import android.content.Context; +import android.content.res.AssetManager; + +import com.google.common.io.ByteStreams; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** Helper class for the Java test in Task Libary. */ +public final class TestUtils { + + /** + * Loads the file and create a {@link File} object by reading a file from the asset directory. + * Simulates downloading or reading a file that's not precompiled with the app. + * + * @return a {@link File} object for the model. + */ + public static File loadFile(Context context, String fileName) { + File target = new File(context.getFilesDir(), fileName); + try (InputStream is = context.getAssets().open(fileName); + FileOutputStream os = new FileOutputStream(target)) { + ByteStreams.copy(is, os); + } catch (IOException e) { + throw new AssertionError("Failed to load model file at " + fileName, e); + } + return target; + } + + /** + * Reads a file into a direct {@link ByteBuffer} object from the asset directory. + * + * @return a {@link ByteBuffer} object for the file. + */ + public static ByteBuffer loadToDirectByteBuffer(Context context, String fileName) + throws IOException { + AssetManager assetManager = context.getAssets(); + InputStream inputStream = assetManager.open(fileName); + byte[] bytes = ByteStreams.toByteArray(inputStream); + + ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder()); + buffer.put(bytes); + return buffer; + } +}
\ No newline at end of file diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java new file mode 100644 index 00000000..a4c2f525 --- /dev/null +++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java @@ -0,0 +1,88 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.task.text.nlclassifier; + +import static com.google.common.truth.Truth.assertThat; + +import androidx.test.core.app.ApplicationProvider; + +import java.io.IOException; +import java.util.List; + +import org.junit.Test; +import org.tensorflow.lite.support.label.Category; +import org.tensorflow.lite.task.core.TestUtils; + +/** Test for {@link BertNLClassifier}. */ +public class BertNLClassifierTest { + private static final String MODEL_FILE = "bert_nl_classifier.tflite"; + + Category findCategoryWithLabel(List<Category> list, String label) { + return list.stream() + .filter(category -> label.equals(category.getLabel())) + .findAny() + .orElse(null); + } + + @Test + public void createFromPath_verifyResults() throws IOException { + verifyResults( + BertNLClassifier.createFromFile(ApplicationProvider.getApplicationContext(), + MODEL_FILE)); + } + + @Test + public void createFromFile_verifyResults() throws IOException { + verifyResults( + BertNLClassifier.createFromFile( + TestUtils.loadFile(ApplicationProvider.getApplicationContext(), + MODEL_FILE))); + } + + @Test + public void classify_succeedsWithModelFile() throws IOException { + verifyResults( + BertNLClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), MODEL_FILE)); + } + + @Test + public void classify_succeedsWithModelBuffer() throws IOException { + verifyResults( + BertNLClassifier.createFromBuffer( + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE))); + } + + @Test + public void getVersion_succeedsWithVersionInMetadata() throws IOException { + BertNLClassifier classifier = BertNLClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), MODEL_FILE); + + assertThat(classifier.getVersion()).isEqualTo("v1"); + } + + private void verifyResults(BertNLClassifier classifier) { + List<Category> negativeResults = classifier.classify("unflinchingly bleak and desperate"); + assertThat(findCategoryWithLabel(negativeResults, "negative").getScore()) + .isGreaterThan(findCategoryWithLabel(negativeResults, "positive").getScore()); + + List<Category> positiveResults = + classifier.classify("it's a charming and often affecting journey"); + assertThat(findCategoryWithLabel(positiveResults, "positive").getScore()) + .isGreaterThan(findCategoryWithLabel(positiveResults, "negative").getScore()); + } +}
\ No newline at end of file diff --git a/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite b/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite Binary files differnew file mode 100644 index 00000000..97a32da4 --- /dev/null +++ b/tensorflow_lite_support/java/src/javatests/testdata/task/text/bert_nl_classifier.tflite diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc index 1edb3507..0866764e 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc @@ -27,6 +27,7 @@ using ::tflite::support::utils::kInvalidPointer; using ::tflite::support::utils::ThrowException; using ::tflite::task::text::nlclassifier::BertNLClassifier; using ::tflite::task::text::nlclassifier::RunClassifier; +using ::tflite::task::text::nlclassifier::GetVersionNative; extern "C" JNIEXPORT void JNICALL Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni( @@ -71,4 +72,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative( return RunClassifier(env, native_handle, text); } +extern "C" JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_getVersionNative( + JNIEnv* env, jclass clazz, jlong native_handle) { + return GetVersionNative(env, native_handle); +} + } // namespace diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc index c358bee1..e6040674 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc @@ -50,6 +50,11 @@ jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) { }); } +jstring GetVersionNative(JNIEnv* env, jlong native_handle) { + auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle); + return env->NewStringUTF(nl_classifier->GetVersion().c_str()); +} + } // namespace nlclassifier } // namespace text } // namespace task diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h index 2c59ab50..2c8fbc07 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h @@ -25,6 +25,8 @@ namespace nlclassifier { jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text); +jstring GetVersionNative(JNIEnv* env, jlong native_handle); + } // namespace nlclassifier } // namespace text } // namespace task diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc index 7cf1b1ca..93263a17 100644 --- a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc +++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc @@ -35,6 +35,7 @@ namespace metadata { namespace { constexpr char kMetadataBufferName[] = "TFLITE_METADATA"; +constexpr char kNoVersionInfo[] = "NO_VERSION_INFO"; using ::absl::StatusCode; using ::flatbuffers::Offset; @@ -380,5 +381,13 @@ int ModelMetadataExtractor::GetOutputProcessUnitsCount() const { return output_process_units == nullptr ? 0 : output_process_units->size(); } +std::string ModelMetadataExtractor::GetVersion() const { + if (model_metadata_ == nullptr || + model_metadata_->version() == nullptr) { + return kNoVersionInfo; + } + return model_metadata_->version()->str(); +} + } // namespace metadata } // namespace tflite
\ No newline at end of file diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/tensorflow_lite_support/metadata/cc/metadata_extractor.h index 9a278970..c73b09b5 100644 --- a/tensorflow_lite_support/metadata/cc/metadata_extractor.h +++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ #define TENSORFLOW_LITE_SUPPORT_METADATA_CC_METADATA_EXTRACTOR_H_ +#include <string> + #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" @@ -127,6 +129,9 @@ class ModelMetadataExtractor { // case there is no output process unit or the index is out of range. const tflite::ProcessUnit* GetOutputProcessUnit(int index) const; + // Gets the model version, or "NO_VERSION_INFO" in case there is no version. + std::string GetVersion() const; + // Gets the count of output process units. In particular, 0 is returned when // there is no output process units. int GetOutputProcessUnitsCount() const; |