diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-06-16 12:26:15 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-06-16 12:26:15 +0000 |
commit | f4198d228105ee9c6e0feaa6d7a7fe7df46546b9 (patch) | |
tree | 1fccb2b92c9b7a466967dd16e16a710ae3b436f8 | |
parent | 45970cd1fc37cba4d13c800a5d52514289ed6539 (diff) | |
parent | 456bf98dcc1a98a825c0d34cb374ba8bbf73083a (diff) | |
download | tflite-support-f4198d228105ee9c6e0feaa6d7a7fe7df46546b9.tar.gz |
Snap for 8734275 from 456bf98dcc1a98a825c0d34cb374ba8bbf73083a to mainline-resolv-releaseaml_res_330910000aml_res_330810000
Change-Id: I3f23ae31440e27ea13b5f2f6da228c74d24c1acc
10 files changed, 95 insertions, 27 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc index 5fafd17c..a75fe0ff 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc @@ -63,6 +63,7 @@ using ::tflite::task::core::PopulateTensor; namespace { constexpr int kRegexTokenizerInputTensorIndex = 0; constexpr int kRegexTokenizerProcessUnitIndex = 0; +constexpr char kNoVersionInfo[] = "NO_VERSION_INFO"; StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile( const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>* @@ -169,6 +170,11 @@ absl::Status NLClassifier::TrySetLabelFromMetadata( labels_vector_ = absl::make_unique<std::vector<std::string>>(LoadVocabFromBuffer( label_buffer.value().data(), label_buffer.value().size())); + if (associated_file->version() == nullptr) { + labels_version_ = kNoVersionInfo; + } else { + labels_version_ = associated_file->version()->str(); + } return absl::OkStatus(); } else { return CreateStatusWithPayload( @@ -184,8 +190,17 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) { return Infer(text).value(); } -std::string NLClassifier::GetVersion() const { - return GetMetadataExtractor()->GetVersion(); +std::string NLClassifier::GetModelVersion() const { + tflite::support::StatusOr<std::string> model_version = + GetMetadataExtractor()->GetModelVersion(); + if (model_version.ok()) { + return model_version.value(); + } + return kNoVersionInfo; +} + +std::string NLClassifier::GetLabelsVersion() const { + return labels_version_; } absl::Status NLClassifier::Preprocess( diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h index 05189dda..013b7d53 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h @@ -113,7 +113,10 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, std::vector<core::Category> Classify(const std::string& text); // Gets the model version, or "NO_VERSION_INFO" in case there is no version. - std::string GetVersion() const; + std::string GetModelVersion() const; + + // Gets the labels version, or "NO_VERSION_INFO" in case there is no version. + std::string GetLabelsVersion() const; protected: static constexpr int kOutputTensorIndex = 0; @@ -173,6 +176,9 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>, // labels vector initialized from output tensor's associated file, if one // exists. std::unique_ptr<std::vector<std::string>> labels_vector_; + // labels version assigned from output tensor's associated file metadata, + // if one exists. + std::string labels_version_; std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_; }; diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java index 0906cd03..65cac78e 100644 --- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java +++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java @@ -123,12 +123,23 @@ public class BertNLClassifier extends BaseTaskApi { } /** - * Gets the model version from the model metadata. + * Gets the model version from the model metadata, + * or "NO_VERSION_INFO" in case there is no version. * * @return The model version. */ - public String getVersion() { - return getVersionNative(getNativeHandle()); + public String getModelVersion() { + return getModelVersionNative(getNativeHandle()); + } + + /** + * Gets the labels version from the model metadata, + * or "NO_VERSION_INFO" in case there is no version. + * + * @return The labels version. + */ + public String getLabelsVersion() { + return getLabelsVersionNative(getNativeHandle()); } private static native long initJniWithByteBuffer(ByteBuffer modelBuffer); @@ -137,7 +148,9 @@ public class BertNLClassifier extends BaseTaskApi { private static native List<Category> classifyNative(long nativeHandle, String text); - private static native String getVersionNative(long nativeHandle); + private static native String getModelVersionNative(long nativeHandle); + + private static native String getLabelsVersionNative(long nativeHandle); @Override protected void deinit(long nativeHandle) { diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java index a4c2f525..f0667d77 100644 --- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java +++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java @@ -68,11 +68,19 @@ public class BertNLClassifierTest { } @Test - public void getVersion_succeedsWithVersionInMetadata() throws IOException { + public void getModelVersion_succeedsWithVersionInMetadata() throws IOException { BertNLClassifier classifier = BertNLClassifier.createFromFile( ApplicationProvider.getApplicationContext(), MODEL_FILE); - assertThat(classifier.getVersion()).isEqualTo("v1"); + assertThat(classifier.getModelVersion()).isEqualTo("v1"); + } + + @Test + public void getLabelsVersion_succeedsWithNoVersionInMetadata() throws IOException { + BertNLClassifier classifier = BertNLClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), MODEL_FILE); + + assertThat(classifier.getLabelsVersion()).isEqualTo("NO_VERSION_INFO"); } private void verifyResults(BertNLClassifier classifier) { diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc index 0866764e..aef82408 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc @@ -27,7 +27,8 @@ using ::tflite::support::utils::kInvalidPointer; using ::tflite::support::utils::ThrowException; using ::tflite::task::text::nlclassifier::BertNLClassifier; using ::tflite::task::text::nlclassifier::RunClassifier; -using ::tflite::task::text::nlclassifier::GetVersionNative; +using ::tflite::task::text::nlclassifier::GetModelVersionNative; +using ::tflite::task::text::nlclassifier::GetLabelsVersionNative; extern "C" JNIEXPORT void JNICALL Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni( @@ -73,9 +74,15 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative( } extern "C" JNIEXPORT jstring JNICALL -Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_getVersionNative( +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_getModelVersionNative( JNIEnv* env, jclass clazz, jlong native_handle) { - return GetVersionNative(env, native_handle); + return GetModelVersionNative(env, native_handle); +} + +extern "C" JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_getLabelsVersionNative( + JNIEnv* env, jclass clazz, jlong native_handle) { + return GetLabelsVersionNative(env, native_handle); } } // namespace diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc index e6040674..6b002f71 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc @@ -50,9 +50,14 @@ jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) { }); } -jstring GetVersionNative(JNIEnv* env, jlong native_handle) { +jstring GetModelVersionNative(JNIEnv* env, jlong native_handle) { auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle); - return env->NewStringUTF(nl_classifier->GetVersion().c_str()); + return env->NewStringUTF(nl_classifier->GetModelVersion().c_str()); +} + +jstring GetLabelsVersionNative(JNIEnv* env, jlong native_handle) { + auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle); + return env->NewStringUTF(nl_classifier->GetLabelsVersion().c_str()); } } // namespace nlclassifier diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h index 2c8fbc07..c21eda70 100644 --- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h +++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h @@ -25,7 +25,9 @@ namespace nlclassifier { jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text); -jstring GetVersionNative(JNIEnv* env, jlong native_handle); +jstring GetModelVersionNative(JNIEnv* env, jlong native_handle); + +jstring GetLabelsVersionNative(JNIEnv* env, jlong native_handle); } // namespace nlclassifier } // namespace text diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc index 93263a17..c2d85bc0 100644 --- a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc +++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc @@ -35,7 +35,6 @@ namespace metadata { namespace { constexpr char kMetadataBufferName[] = "TFLITE_METADATA"; -constexpr char kNoVersionInfo[] = "NO_VERSION_INFO"; using ::absl::StatusCode; using ::flatbuffers::Offset; @@ -291,6 +290,23 @@ ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const { return it->second; } +tflite::support::StatusOr<std::string> +ModelMetadataExtractor::GetModelVersion() const { + if (model_metadata_ == nullptr) { + return CreateStatusWithPayload( + StatusCode::kFailedPrecondition, + "No model metadata", + TfLiteSupportStatus::kMetadataNotFoundError); + } + if (model_metadata_->version() == nullptr) { + return CreateStatusWithPayload( + StatusCode::kNotFound, + "No version in model metadata", + TfLiteSupportStatus::kMetadataNotFoundError); + } + return model_metadata_->version()->str(); +} + const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>* ModelMetadataExtractor::GetInputTensorMetadata() const { if (model_metadata_ == nullptr || @@ -381,13 +397,5 @@ int ModelMetadataExtractor::GetOutputProcessUnitsCount() const { return output_process_units == nullptr ? 0 : output_process_units->size(); } -std::string ModelMetadataExtractor::GetVersion() const { - if (model_metadata_ == nullptr || - model_metadata_->version() == nullptr) { - return kNoVersionInfo; - } - return model_metadata_->version()->str(); -} - } // namespace metadata } // namespace tflite
\ No newline at end of file diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/tensorflow_lite_support/metadata/cc/metadata_extractor.h index c73b09b5..bff308ab 100644 --- a/tensorflow_lite_support/metadata/cc/metadata_extractor.h +++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.h @@ -80,6 +80,10 @@ class ModelMetadataExtractor { tflite::support::StatusOr<absl::string_view> GetAssociatedFile( const std::string& filename) const; + // Gets the model version from the model metadata. An error is returned if + // either the metadata does not exist or no model version is present in it. + tflite::support::StatusOr<std::string> GetModelVersion() const; + // Note: all methods below retrieves metadata of the *first* subgraph as // default. @@ -129,9 +133,6 @@ class ModelMetadataExtractor { // case there is no output process unit or the index is out of range. const tflite::ProcessUnit* GetOutputProcessUnit(int index) const; - // Gets the model version, or "NO_VERSION_INFO" in case there is no version. - std::string GetVersion() const; - // Gets the count of output process units. In particular, 0 is returned when // there is no output process units. int GetOutputProcessUnitsCount() const; diff --git a/tensorflow_lite_support/metadata/metadata_schema.fbs b/tensorflow_lite_support/metadata/metadata_schema.fbs index 8faae0a8..6ce94525 100644 --- a/tensorflow_lite_support/metadata/metadata_schema.fbs +++ b/tensorflow_lite_support/metadata/metadata_schema.fbs @@ -152,6 +152,9 @@ table AssociatedFile { // Leverage this in order to specify e.g multiple label files translated in // different languages. locale:string; + + // Version of the file specified by model creators. + version:string; } // The basic content type for all tensors. |