aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-06-16 12:26:15 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-06-16 12:26:15 +0000
commitf4198d228105ee9c6e0feaa6d7a7fe7df46546b9 (patch)
tree1fccb2b92c9b7a466967dd16e16a710ae3b436f8
parent45970cd1fc37cba4d13c800a5d52514289ed6539 (diff)
parent456bf98dcc1a98a825c0d34cb374ba8bbf73083a (diff)
downloadtflite-support-f4198d228105ee9c6e0feaa6d7a7fe7df46546b9.tar.gz
Snap for 8734275 from 456bf98dcc1a98a825c0d34cb374ba8bbf73083a to mainline-resolv-releaseaml_res_330910000aml_res_330810000
Change-Id: I3f23ae31440e27ea13b5f2f6da228c74d24c1acc
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc19
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h8
-rw-r--r--tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java21
-rw-r--r--tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java12
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc13
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc9
-rw-r--r--tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h4
-rw-r--r--tensorflow_lite_support/metadata/cc/metadata_extractor.cc26
-rw-r--r--tensorflow_lite_support/metadata/cc/metadata_extractor.h7
-rw-r--r--tensorflow_lite_support/metadata/metadata_schema.fbs3
10 files changed, 95 insertions, 27 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
index 5fafd17c..a75fe0ff 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc
@@ -63,6 +63,7 @@ using ::tflite::task::core::PopulateTensor;
namespace {
constexpr int kRegexTokenizerInputTensorIndex = 0;
constexpr int kRegexTokenizerProcessUnitIndex = 0;
+constexpr char kNoVersionInfo[] = "NO_VERSION_INFO";
StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
@@ -169,6 +170,11 @@ absl::Status NLClassifier::TrySetLabelFromMetadata(
labels_vector_ =
absl::make_unique<std::vector<std::string>>(LoadVocabFromBuffer(
label_buffer.value().data(), label_buffer.value().size()));
+ if (associated_file->version() == nullptr) {
+ labels_version_ = kNoVersionInfo;
+ } else {
+ labels_version_ = associated_file->version()->str();
+ }
return absl::OkStatus();
} else {
return CreateStatusWithPayload(
@@ -184,8 +190,17 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) {
return Infer(text).value();
}
-std::string NLClassifier::GetVersion() const {
- return GetMetadataExtractor()->GetVersion();
+std::string NLClassifier::GetModelVersion() const {
+ tflite::support::StatusOr<std::string> model_version =
+ GetMetadataExtractor()->GetModelVersion();
+ if (model_version.ok()) {
+ return model_version.value();
+ }
+ return kNoVersionInfo;
+}
+
+std::string NLClassifier::GetLabelsVersion() const {
+ return labels_version_;
}
absl::Status NLClassifier::Preprocess(
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
index 05189dda..013b7d53 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h
@@ -113,7 +113,10 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
std::vector<core::Category> Classify(const std::string& text);
// Gets the model version, or "NO_VERSION_INFO" in case there is no version.
- std::string GetVersion() const;
+ std::string GetModelVersion() const;
+
+ // Gets the labels version, or "NO_VERSION_INFO" in case there is no version.
+ std::string GetLabelsVersion() const;
protected:
static constexpr int kOutputTensorIndex = 0;
@@ -173,6 +176,9 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
// labels vector initialized from output tensor's associated file, if one
// exists.
std::unique_ptr<std::vector<std::string>> labels_vector_;
+ // labels version assigned from output tensor's associated file metadata,
+ // if one exists.
+ std::string labels_version_;
std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_;
};
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
index 0906cd03..65cac78e 100644
--- a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java
@@ -123,12 +123,23 @@ public class BertNLClassifier extends BaseTaskApi {
}
/**
- * Gets the model version from the model metadata.
+ * Gets the model version from the model metadata,
+ * or "NO_VERSION_INFO" in case there is no version.
*
* @return The model version.
*/
- public String getVersion() {
- return getVersionNative(getNativeHandle());
+ public String getModelVersion() {
+ return getModelVersionNative(getNativeHandle());
+ }
+
+ /**
+ * Gets the labels version from the model metadata,
+ * or "NO_VERSION_INFO" in case there is no version.
+ *
+ * @return The labels version.
+ */
+ public String getLabelsVersion() {
+ return getLabelsVersionNative(getNativeHandle());
}
private static native long initJniWithByteBuffer(ByteBuffer modelBuffer);
@@ -137,7 +148,9 @@ public class BertNLClassifier extends BaseTaskApi {
private static native List<Category> classifyNative(long nativeHandle, String text);
- private static native String getVersionNative(long nativeHandle);
+ private static native String getModelVersionNative(long nativeHandle);
+
+ private static native String getLabelsVersionNative(long nativeHandle);
@Override
protected void deinit(long nativeHandle) {
diff --git a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
index a4c2f525..f0667d77 100644
--- a/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
+++ b/tensorflow_lite_support/java/src/javatests/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifierTest.java
@@ -68,11 +68,19 @@ public class BertNLClassifierTest {
}
@Test
- public void getVersion_succeedsWithVersionInMetadata() throws IOException {
+ public void getModelVersion_succeedsWithVersionInMetadata() throws IOException {
BertNLClassifier classifier = BertNLClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), MODEL_FILE);
- assertThat(classifier.getVersion()).isEqualTo("v1");
+ assertThat(classifier.getModelVersion()).isEqualTo("v1");
+ }
+
+ @Test
+ public void getLabelsVersion_succeedsWithNoVersionInMetadata() throws IOException {
+ BertNLClassifier classifier = BertNLClassifier.createFromFile(
+ ApplicationProvider.getApplicationContext(), MODEL_FILE);
+
+ assertThat(classifier.getLabelsVersion()).isEqualTo("NO_VERSION_INFO");
}
private void verifyResults(BertNLClassifier classifier) {
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
index 0866764e..aef82408 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert_nlclassifier/bert_nl_classifier_jni.cc
@@ -27,7 +27,8 @@ using ::tflite::support::utils::kInvalidPointer;
using ::tflite::support::utils::ThrowException;
using ::tflite::task::text::nlclassifier::BertNLClassifier;
using ::tflite::task::text::nlclassifier::RunClassifier;
-using ::tflite::task::text::nlclassifier::GetVersionNative;
+using ::tflite::task::text::nlclassifier::GetModelVersionNative;
+using ::tflite::task::text::nlclassifier::GetLabelsVersionNative;
extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni(
@@ -73,9 +74,15 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative(
}
extern "C" JNIEXPORT jstring JNICALL
-Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_getVersionNative(
+Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_getModelVersionNative(
JNIEnv* env, jclass clazz, jlong native_handle) {
- return GetVersionNative(env, native_handle);
+ return GetModelVersionNative(env, native_handle);
+}
+
+extern "C" JNIEXPORT jstring JNICALL
+Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_getLabelsVersionNative(
+ JNIEnv* env, jclass clazz, jlong native_handle) {
+ return GetLabelsVersionNative(env, native_handle);
}
} // namespace
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
index e6040674..6b002f71 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.cc
@@ -50,9 +50,14 @@ jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text) {
});
}
-jstring GetVersionNative(JNIEnv* env, jlong native_handle) {
+jstring GetModelVersionNative(JNIEnv* env, jlong native_handle) {
auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle);
- return env->NewStringUTF(nl_classifier->GetVersion().c_str());
+ return env->NewStringUTF(nl_classifier->GetModelVersion().c_str());
+}
+
+jstring GetLabelsVersionNative(JNIEnv* env, jlong native_handle) {
+ auto* nl_classifier = reinterpret_cast<NLClassifier*>(native_handle);
+ return env->NewStringUTF(nl_classifier->GetLabelsVersion().c_str());
}
} // namespace nlclassifier
diff --git a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h
index 2c8fbc07..c21eda70 100644
--- a/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h
+++ b/tensorflow_lite_support/java/src/native/task/text/nlclassifier/nl_classifier_jni_utils.h
@@ -25,7 +25,9 @@ namespace nlclassifier {
jobject RunClassifier(JNIEnv* env, jlong native_handle, jstring text);
-jstring GetVersionNative(JNIEnv* env, jlong native_handle);
+jstring GetModelVersionNative(JNIEnv* env, jlong native_handle);
+
+jstring GetLabelsVersionNative(JNIEnv* env, jlong native_handle);
} // namespace nlclassifier
} // namespace text
diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
index 93263a17..c2d85bc0 100644
--- a/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
+++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.cc
@@ -35,7 +35,6 @@ namespace metadata {
namespace {
constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
-constexpr char kNoVersionInfo[] = "NO_VERSION_INFO";
using ::absl::StatusCode;
using ::flatbuffers::Offset;
@@ -291,6 +290,23 @@ ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const {
return it->second;
}
+tflite::support::StatusOr<std::string>
+ModelMetadataExtractor::GetModelVersion() const {
+ if (model_metadata_ == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kFailedPrecondition,
+ "No model metadata",
+ TfLiteSupportStatus::kMetadataNotFoundError);
+ }
+ if (model_metadata_->version() == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kNotFound,
+ "No version in model metadata",
+ TfLiteSupportStatus::kMetadataNotFoundError);
+ }
+ return model_metadata_->version()->str();
+}
+
const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
ModelMetadataExtractor::GetInputTensorMetadata() const {
if (model_metadata_ == nullptr ||
@@ -381,13 +397,5 @@ int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
return output_process_units == nullptr ? 0 : output_process_units->size();
}
-std::string ModelMetadataExtractor::GetVersion() const {
- if (model_metadata_ == nullptr ||
- model_metadata_->version() == nullptr) {
- return kNoVersionInfo;
- }
- return model_metadata_->version()->str();
-}
-
} // namespace metadata
} // namespace tflite \ No newline at end of file
diff --git a/tensorflow_lite_support/metadata/cc/metadata_extractor.h b/tensorflow_lite_support/metadata/cc/metadata_extractor.h
index c73b09b5..bff308ab 100644
--- a/tensorflow_lite_support/metadata/cc/metadata_extractor.h
+++ b/tensorflow_lite_support/metadata/cc/metadata_extractor.h
@@ -80,6 +80,10 @@ class ModelMetadataExtractor {
tflite::support::StatusOr<absl::string_view> GetAssociatedFile(
const std::string& filename) const;
+ // Gets the model version from the model metadata. An error is returned if
+ // either the metadata does not exist or no model version is present in it.
+ tflite::support::StatusOr<std::string> GetModelVersion() const;
+
// Note: all methods below retrieves metadata of the *first* subgraph as
// default.
@@ -129,9 +133,6 @@ class ModelMetadataExtractor {
// case there is no output process unit or the index is out of range.
const tflite::ProcessUnit* GetOutputProcessUnit(int index) const;
- // Gets the model version, or "NO_VERSION_INFO" in case there is no version.
- std::string GetVersion() const;
-
// Gets the count of output process units. In particular, 0 is returned when
// there is no output process units.
int GetOutputProcessUnitsCount() const;
diff --git a/tensorflow_lite_support/metadata/metadata_schema.fbs b/tensorflow_lite_support/metadata/metadata_schema.fbs
index 8faae0a8..6ce94525 100644
--- a/tensorflow_lite_support/metadata/metadata_schema.fbs
+++ b/tensorflow_lite_support/metadata/metadata_schema.fbs
@@ -152,6 +152,9 @@ table AssociatedFile {
// Leverage this in order to specify e.g multiple label files translated in
// different languages.
locale:string;
+
+ // Version of the file specified by model creators.
+ version:string;
}
// The basic content type for all tensors.