aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow_lite_support/cc/task/vision/core/label_map_item.cc')
-rw-r--r--tensorflow_lite_support/cc/task/vision/core/label_map_item.cc128
1 files changed, 128 insertions, 0 deletions
diff --git a/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc b/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
new file mode 100644
index 00000000..75b1fc60
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc
@@ -0,0 +1,128 @@
+
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow_lite_support/cc/common.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+using ::absl::StatusCode;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+
+StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
+ absl::string_view labels_file, absl::string_view display_names_file) {
+ if (labels_file.empty()) {
+ return CreateStatusWithPayload(StatusCode::kInvalidArgument,
+ "Expected non-empty labels file.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ std::vector<absl::string_view> labels = absl::StrSplit(labels_file, '\n');
+ // In most cases, there is an empty line (i.e. newline character) at the end
+ // of the file that needs to be ignored. In such a situation, StrSplit() will
+ // produce a vector with an empty string as final element. Also note that in
+ // case `labels_file` is entirely empty, StrSplit() will produce a vector with
+ // one single empty substring, so there's no out-of-range risk here.
+ if (labels[labels.size() - 1].empty()) {
+ labels.pop_back();
+ }
+
+ std::vector<LabelMapItem> label_map_items;
+ label_map_items.reserve(labels.size());
+ for (int i = 0; i < labels.size(); ++i) {
+ label_map_items.emplace_back(LabelMapItem{.name = std::string(labels[i])});
+ }
+
+ if (!display_names_file.empty()) {
+ std::vector<std::string> display_names =
+ absl::StrSplit(display_names_file, '\n');
+ // In most cases, there is an empty line (i.e. newline character) at the end
+ // of the file that needs to be ignored. See above.
+ if (display_names[display_names.size() - 1].empty()) {
+ display_names.pop_back();
+ }
+ if (display_names.size() != labels.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Mismatch between number of labels (%d) and display names (%d).",
+ labels.size(), display_names.size()),
+ TfLiteSupportStatus::kMetadataNumLabelsMismatchError);
+ }
+ for (int i = 0; i < display_names.size(); ++i) {
+ label_map_items[i].display_name = display_names[i];
+ }
+ }
+ return label_map_items;
+}
+
+absl::Status LabelHierarchy::InitializeFromLabelMap(
+ std::vector<LabelMapItem> label_map_items) {
+ parents_map_.clear();
+ for (const LabelMapItem& label : label_map_items) {
+ for (const std::string& child_name : label.child_name) {
+ parents_map_[child_name].insert(label.name);
+ }
+ }
+ if (parents_map_.empty()) {
+ return CreateStatusWithPayload(StatusCode::kInvalidArgument,
+ "Input labelmap is not hierarchical: there "
+ "is no parent-child relationship.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ return absl::OkStatus();
+}
+
+bool LabelHierarchy::HaveAncestorDescendantRelationship(
+ const std::string& ancestor_name,
+ const std::string& descendant_name) const {
+ absl::flat_hash_set<std::string> ancestors;
+ GetAncestors(descendant_name, &ancestors);
+ return ancestors.contains(ancestor_name);
+}
+
+absl::flat_hash_set<std::string> LabelHierarchy::GetParents(
+ const std::string& name) const {
+ absl::flat_hash_set<std::string> parents;
+ auto it = parents_map_.find(name);
+ if (it != parents_map_.end()) {
+ for (const std::string& parent_name : it->second) {
+ parents.insert(parent_name);
+ }
+ }
+ return parents;
+}
+
+void LabelHierarchy::GetAncestors(
+ const std::string& name,
+ absl::flat_hash_set<std::string>* ancestors) const {
+ const absl::flat_hash_set<std::string> parents = GetParents(name);
+ for (const std::string& parent_name : parents) {
+ auto it = ancestors->insert(parent_name);
+ if (it.second) {
+ GetAncestors(parent_name, ancestors);
+ }
+ }
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite