diff options
Diffstat (limited to 'tensorflow_lite_support/cc/task/vision/core/label_map_item.cc')
-rw-r--r-- | tensorflow_lite_support/cc/task/vision/core/label_map_item.cc | 128 |
1 files changed, 128 insertions, 0 deletions
diff --git a/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc b/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc new file mode 100644 index 00000000..75b1fc60 --- /dev/null +++ b/tensorflow_lite_support/cc/task/vision/core/label_map_item.cc @@ -0,0 +1,128 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" +#include "tensorflow_lite_support/cc/common.h" + +namespace tflite { +namespace task { +namespace vision { + +using ::absl::StatusCode; +using ::tflite::support::CreateStatusWithPayload; +using ::tflite::support::StatusOr; +using ::tflite::support::TfLiteSupportStatus; + +StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles( + absl::string_view labels_file, absl::string_view display_names_file) { + if (labels_file.empty()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Expected non-empty labels file.", + TfLiteSupportStatus::kInvalidArgumentError); + } + std::vector<absl::string_view> labels = absl::StrSplit(labels_file, '\n'); + // In most cases, there is an empty line (i.e. newline character) at the end + // of the file that needs to be ignored. In such a situation, StrSplit() will + // produce a vector with an empty string as final element. Also note that in + // case `labels_file` is entirely empty, StrSplit() will produce a vector with + // one single empty substring, so there's no out-of-range risk here. + if (labels[labels.size() - 1].empty()) { + labels.pop_back(); + } + + std::vector<LabelMapItem> label_map_items; + label_map_items.reserve(labels.size()); + for (int i = 0; i < labels.size(); ++i) { + label_map_items.emplace_back(LabelMapItem{.name = std::string(labels[i])}); + } + + if (!display_names_file.empty()) { + std::vector<std::string> display_names = + absl::StrSplit(display_names_file, '\n'); + // In most cases, there is an empty line (i.e. newline character) at the end + // of the file that needs to be ignored. See above. + if (display_names[display_names.size() - 1].empty()) { + display_names.pop_back(); + } + if (display_names.size() != labels.size()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Mismatch between number of labels (%d) and display names (%d).", + labels.size(), display_names.size()), + TfLiteSupportStatus::kMetadataNumLabelsMismatchError); + } + for (int i = 0; i < display_names.size(); ++i) { + label_map_items[i].display_name = display_names[i]; + } + } + return label_map_items; +} + +absl::Status LabelHierarchy::InitializeFromLabelMap( + std::vector<LabelMapItem> label_map_items) { + parents_map_.clear(); + for (const LabelMapItem& label : label_map_items) { + for (const std::string& child_name : label.child_name) { + parents_map_[child_name].insert(label.name); + } + } + if (parents_map_.empty()) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Input labelmap is not hierarchical: there " + "is no parent-child relationship.", + TfLiteSupportStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +bool LabelHierarchy::HaveAncestorDescendantRelationship( + const std::string& ancestor_name, + const std::string& descendant_name) const { + absl::flat_hash_set<std::string> ancestors; + GetAncestors(descendant_name, &ancestors); + return ancestors.contains(ancestor_name); +} + +absl::flat_hash_set<std::string> LabelHierarchy::GetParents( + const std::string& name) const { + absl::flat_hash_set<std::string> parents; + auto it = parents_map_.find(name); + if (it != parents_map_.end()) { + for (const std::string& parent_name : it->second) { + parents.insert(parent_name); + } + } + return parents; +} + +void LabelHierarchy::GetAncestors( + const std::string& name, + absl::flat_hash_set<std::string>* ancestors) const { + const absl::flat_hash_set<std::string> parents = GetParents(name); + for (const std::string& parent_name : parents) { + auto it = ancestors->insert(parent_name); + if (it.second) { + GetAncestors(parent_name, ancestors); + } + } +} + +} // namespace vision +} // namespace task +} // namespace tflite |