path: root/tensorflow_lite_support/cc/task/vision/image_classifier.cc
Diffstat (limited to 'tensorflow_lite_support/cc/task/vision/image_classifier.cc')
-rw-r--r--  tensorflow_lite_support/cc/task/vision/image_classifier.cc  572
1 file changed, 572 insertions, 0 deletions
diff --git a/tensorflow_lite_support/cc/task/vision/image_classifier.cc b/tensorflow_lite_support/cc/task/vision/image_classifier.cc
new file mode 100644
index 00000000..378797b4
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/image_classifier.cc
@@ -0,0 +1,572 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "flatbuffers/flatbuffers.h" // from @flatbuffers
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow_lite_support/cc/common.h"
+#include "tensorflow_lite_support/cc/port/integral_types.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+#include "tensorflow_lite_support/cc/task/vision/core/label_map_item.h"
+#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+namespace {
+
+using ::absl::StatusCode;
+using ::tflite::metadata::ModelMetadataExtractor;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::task::core::AssertAndReturnTypedTensor;
+using ::tflite::task::core::TaskAPIFactory;
+using ::tflite::task::core::TfLiteEngine;
+
+// Default score value used as a fallback for classes that (1) have no score
+// calibration data or (2) have a very low-confidence uncalibrated score, i.e.
+// one lower than the `min_uncalibrated_score` threshold.
+//
+// (1) This happens when the ScoreCalibration data does not cover all the
+// classes listed in the label map. It can be used to enforce the blacklisting
+// of given classes so that they are never returned.
+//
+// (2) This is an optional threshold provided as part of the calibration data.
+// It is used to mitigate false alarms on some classes.
+//
+// In both cases, a class that gets assigned a score of -1 is never returned as
+// it gets discarded by the `score_threshold` check (see post-processing logic).
+constexpr float kDefaultCalibratedScore = -1.0f;
+
+// Calibrated scores should be in the [0, 1] range, otherwise an error is
+// returned at post-processing time.
+constexpr float kMinCalibratedScore = 0.0f;
+constexpr float kMaxCalibratedScore = 1.0f;
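+
+// For example, since a valid `score_threshold` must lie in [0, 1) (see
+// SanityCheckOptions), a class assigned the -1.0 fallback score always fails
+// the `score < score_threshold` cut-off in Postprocess and is dropped, while
+// calibrated scores in [0, 1] remain eligible.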
+
+} // namespace
+
+/* static */
+StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::CreateFromOptions(
+ const ImageClassifierOptions& options,
+ std::unique_ptr<tflite::OpResolver> resolver) {
+ RETURN_IF_ERROR(SanityCheckOptions(options));
+
+ // Copy options to ensure the ExternalFile outlives the constructed object.
+ auto options_copy = absl::make_unique<ImageClassifierOptions>(options);
+
+ ASSIGN_OR_RETURN(auto image_classifier,
+ TaskAPIFactory::CreateFromExternalFileProto<ImageClassifier>(
+ &options_copy->model_file_with_metadata(),
+ std::move(resolver), options_copy->num_threads()));
+
+ RETURN_IF_ERROR(image_classifier->Init(std::move(options_copy)));
+
+ return image_classifier;
+}
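+
+// Illustrative usage sketch (not part of the original file; the file path and
+// option values are made up, and the `resolver` argument is assumed to take a
+// default value declared in the header):
+//
+//   ImageClassifierOptions options;
+//   options.mutable_model_file_with_metadata()->set_file_name(
+//       "/path/to/model_with_metadata.tflite");
+//   options.set_max_results(3);
+//   ASSIGN_OR_RETURN(std::unique_ptr<ImageClassifier> classifier,
+//                    ImageClassifier::CreateFromOptions(options));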
+
+/* static */
+absl::Status ImageClassifier::SanityCheckOptions(
+ const ImageClassifierOptions& options) {
+ if (!options.has_model_file_with_metadata()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Missing mandatory `model_file_with_metadata` field",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.max_results() == 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "Invalid `max_results` option: value must be != 0",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.score_threshold() < 0 || options.score_threshold() >= 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "`score_threshold` out of range: %f. Valid range is [0,1[.",
+ options.score_threshold()),
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.class_name_whitelist_size() > 0 &&
+ options.class_name_blacklist_size() > 0) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "`class_name_whitelist` and `class_name_blacklist` are mutually "
+ "exclusive options.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ if (options.num_threads() == 0 || options.num_threads() < -1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ "`num_threads` must be greater than 0 or equal to -1.",
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+ return absl::OkStatus();
+}
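+
+// For reference (illustrative values, not part of the original file): options
+// with max_results = 5, score_threshold = 0.3, num_threads = -1 and at most
+// one of `class_name_whitelist` / `class_name_blacklist` pass the checks
+// above, whereas max_results = 0 or score_threshold = 1.0 are rejected.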
+
+absl::Status ImageClassifier::Init(
+ std::unique_ptr<ImageClassifierOptions> options) {
+ // Set options.
+ options_ = std::move(options);
+
+  // Perform pre-initialization actions (by default, this sets the process
+  // engine used for image pre-processing to kLibyuv).
+ RETURN_IF_ERROR(PreInit());
+
+ // Sanity check and set inputs and outputs.
+ RETURN_IF_ERROR(CheckAndSetInputs());
+ RETURN_IF_ERROR(CheckAndSetOutputs());
+
+ // Initialize class whitelisting/blacklisting, if any.
+ RETURN_IF_ERROR(CheckAndSetClassNameSet());
+
+ // Perform final initialization (by default, initialize score calibration
+ // parameters, if any).
+ RETURN_IF_ERROR(PostInit());
+
+ return absl::OkStatus();
+}
+
+absl::Status ImageClassifier::PreInit() {
+ SetProcessEngine(FrameBufferUtils::ProcessEngine::kLibyuv);
+ return absl::OkStatus();
+}
+
+absl::Status ImageClassifier::PostInit() { return InitScoreCalibrations(); }
+
+absl::Status ImageClassifier::CheckAndSetOutputs() {
+ num_outputs_ = TfLiteEngine::OutputCount(engine_->interpreter());
+
+ // Perform sanity checks and extract metadata.
+ const ModelMetadataExtractor* metadata_extractor =
+ engine_->metadata_extractor();
+
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
+ output_tensor_metadata = metadata_extractor->GetOutputTensorMetadata();
+
+ // Loop over output tensors metadata, if any.
+ // Note: models with no output tensor metadata at all are supported.
+ if (output_tensor_metadata != nullptr) {
+ int num_output_tensors = output_tensor_metadata->size();
+
+ if (num_outputs_ != num_output_tensors) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Mismatch between number of output tensors (%d) and "
+ "output tensors "
+ "metadata (%d).",
+ num_outputs_, num_output_tensors),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+
+ for (int i = 0; i < num_output_tensors; ++i) {
+ const tflite::TensorMetadata* output_tensor =
+ output_tensor_metadata->Get(i);
+
+ ASSIGN_OR_RETURN(
+ ClassificationHead head,
+ BuildClassificationHead(*metadata_extractor, *output_tensor,
+ options_->display_names_locale()));
+
+ classification_heads_.emplace_back(std::move(head));
+ }
+ }
+
+ // If classifier heads are not set, build default ones based on model
+ // introspection. This happens if a model with partial or no metadata was
+ // provided through the `model_file_with_metadata` options field.
+ if (classification_heads_.empty()) {
+ classification_heads_.reserve(num_outputs_);
+ for (int output_index = 0; output_index < num_outputs_; ++output_index) {
+ classification_heads_.emplace_back(ClassificationHead{});
+ }
+ }
+
+ if (num_outputs_ != classification_heads_.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Got %d classifier head(s), expected %d according to "
+ "the label map.",
+ num_outputs_, classification_heads_.size()),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+
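+  // The loop below accepts, for instance, a float32 output of shape [1, 1001]
+  // or a uint8 output of shape [1, 1, 1, 1001] (sizes are illustrative); other
+  // layouts or tensor types are rejected.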
+ int num_quantized_outputs = 0;
+ for (int i = 0; i < num_outputs_; ++i) {
+ const TfLiteTensor* output_tensor =
+ TfLiteEngine::GetOutput(engine_->interpreter(), i);
+ const int num_dimensions = output_tensor->dims->size;
+ if (num_dimensions == 4) {
+ if (output_tensor->dims->data[1] != 1 ||
+ output_tensor->dims->data[2] != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Unexpected WxH sizes for output index %d: got "
+ "%dx%d, expected 1x1.",
+ i, output_tensor->dims->data[2],
+ output_tensor->dims->data[1]),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ } else if (num_dimensions != 2) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Unexpected number of dimensions for output index %d: got %dD, "
+ "expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, "
+ "H=1).",
+ i, num_dimensions),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ if (output_tensor->dims->data[0] != 1) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("The output array is expected to have a batch size "
+ "of 1. Got %d for output index %d.",
+ output_tensor->dims->data[0], i),
+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
+ }
+ int num_classes = output_tensor->dims->data[num_dimensions - 1];
+ // If label map is not set, build a default one based on model
+ // introspection. This happens if a model with partial or no metadata was
+ // provided through the `model_file_with_metadata` options field.
+ if (classification_heads_[i].label_map_items.empty()) {
+ classification_heads_[i].label_map_items.reserve(num_classes);
+ for (int class_index = 0; class_index < num_classes; ++class_index) {
+ classification_heads_[i].label_map_items.emplace_back(LabelMapItem{});
+ }
+ }
+ int num_label_map_items = classification_heads_[i].label_map_items.size();
+ if (num_classes != num_label_map_items) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Got %d class(es) for output index %d, expected %d "
+ "according to the label map.",
+ output_tensor->dims->data[num_dimensions - 1], i,
+ num_label_map_items),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+ if (output_tensor->type == kTfLiteUInt8) {
+ num_quantized_outputs++;
+ } else if (output_tensor->type != kTfLiteFloat32) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Type mismatch for output tensor %s. Requested one "
+ "of these types: "
+ "kTfLiteUint8/kTfLiteFloat32, got %s.",
+ output_tensor->name,
+ TfLiteTypeGetName(output_tensor->type)),
+ TfLiteSupportStatus::kInvalidOutputTensorTypeError);
+ }
+ }
+
+ if (num_quantized_outputs > 0 && num_quantized_outputs != num_outputs_) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Got %d quantized output(s), expected %d (i.e. all "
+ "provided outputs must be quantized).",
+ num_quantized_outputs, num_outputs_),
+ TfLiteSupportStatus::kInvalidOutputTensorTypeError);
+ }
+ has_uint8_outputs_ = (num_quantized_outputs > 0);
+
+ return absl::OkStatus();
+}
+
+absl::Status ImageClassifier::CheckAndSetClassNameSet() {
+ // Exit early if no blacklist/whitelist.
+ if (options_->class_name_blacklist_size() == 0 &&
+ options_->class_name_whitelist_size() == 0) {
+ return absl::OkStatus();
+ }
+
+  // Before processing the class name whitelist or blacklist from the input
+  // options, create a set with _all_ known class names from the label map(s).
+ absl::flat_hash_set<std::string> all_class_names;
+ int head_index = 0;
+ for (const auto& head : classification_heads_) {
+ absl::flat_hash_set<std::string> head_class_names;
+ for (const auto& item : head.label_map_items) {
+ if (!item.name.empty()) {
+ head_class_names.insert(item.name);
+ }
+ }
+ if (head_class_names.empty()) {
+ std::string name = head.name;
+ if (name.empty()) {
+ name = absl::StrFormat("#%d", head_index);
+ }
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Using `class_name_whitelist` or `class_name_blacklist` "
+ "requires labels to be present but none was found for "
+ "classification head: %s",
+ name),
+ TfLiteSupportStatus::kMetadataMissingLabelsError);
+ }
+ all_class_names.insert(head_class_names.begin(), head_class_names.end());
+ head_index++;
+ }
+
+ class_name_set_.is_whitelist = options_->class_name_whitelist_size() > 0;
+ const auto& class_names = class_name_set_.is_whitelist
+ ? options_->class_name_whitelist()
+ : options_->class_name_blacklist();
+
+ // Note: duplicate or unknown classes are just ignored.
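+  // For example (illustrative labels): with model labels {"cat", "dog"} and a
+  // whitelist of {"dog", "bird"}, only "dog" is kept below; the unknown "bird"
+  // entry is silently dropped.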
+ class_name_set_.values.clear();
+ for (const auto& class_name : class_names) {
+ if (!all_class_names.contains(class_name)) {
+ continue;
+ }
+ class_name_set_.values.insert(class_name);
+ }
+
+ if (class_name_set_.values.empty()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat(
+ "Invalid class names specified via `class_name_%s`: none match "
+ "with model labels.",
+ class_name_set_.is_whitelist ? "whitelist" : "blacklist"),
+ TfLiteSupportStatus::kInvalidArgumentError);
+ }
+
+ return absl::OkStatus();
+}
+
+absl::Status ImageClassifier::InitScoreCalibrations() {
+ score_calibrations_.clear();
+ score_calibrations_.resize(classification_heads_.size());
+
+ for (int i = 0; i < classification_heads_.size(); ++i) {
+ if (!classification_heads_[i].calibration_params.has_value()) {
+ continue;
+ }
+
+ // Use a specific default score instead of the one specified by default in
+ // cc/task/vision/utils/score_calibration.h. See `kDefaultCalibratedScore`
+ // documentation for more details.
+ classification_heads_[i].calibration_params->default_score =
+ kDefaultCalibratedScore;
+
+ score_calibrations_[i] = absl::make_unique<ScoreCalibration>();
+ if (score_calibrations_[i] == nullptr) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal, "Could not create score calibration object.");
+ }
+
+ RETURN_IF_ERROR(score_calibrations_[i]->InitializeFromParameters(
+ classification_heads_[i].calibration_params.value()));
+ }
+
+ return absl::OkStatus();
+}
+
+StatusOr<ClassificationResult> ImageClassifier::Classify(
+ const FrameBuffer& frame_buffer) {
+ BoundingBox roi;
+ roi.set_width(frame_buffer.dimension().width);
+ roi.set_height(frame_buffer.dimension().height);
+ return Classify(frame_buffer, roi);
+}
+
+StatusOr<ClassificationResult> ImageClassifier::Classify(
+ const FrameBuffer& frame_buffer, const BoundingBox& roi) {
+ return InferWithFallback(frame_buffer, roi);
+}
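+
+// Illustrative call with a region of interest (values made up; `frame_buffer`
+// and the `origin_x`/`origin_y` field names are assumptions not shown in this
+// file):
+//
+//   BoundingBox roi;
+//   roi.set_origin_x(16);
+//   roi.set_origin_y(16);
+//   roi.set_width(128);
+//   roi.set_height(128);
+//   ASSIGN_OR_RETURN(ClassificationResult result,
+//                    classifier->Classify(*frame_buffer, roi));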
+
+StatusOr<ClassificationResult> ImageClassifier::Postprocess(
+ const std::vector<const TfLiteTensor*>& output_tensors,
+ const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
+ if (output_tensors.size() != num_outputs_) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("Expected %d output tensors, found %d", num_outputs_,
+ output_tensors.size()));
+ }
+
+ ClassificationResult result;
+ std::vector<std::pair<int, float>> score_pairs;
+
+ for (int i = 0; i < num_outputs_; ++i) {
+ auto* classifications = result.add_classifications();
+ classifications->set_head_index(i);
+
+ const auto& head = classification_heads_[i];
+ score_pairs.clear();
+ score_pairs.reserve(head.label_map_items.size());
+
+ const TfLiteTensor* output_tensor = output_tensors[i];
+ if (has_uint8_outputs_) {
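+      // This branch dequantizes each value as
+      //   score = scale * (quantized_value - zero_point);
+      // e.g. scale = 1/256 and zero_point = 0 map a raw value of 128 to 0.5
+      // (illustrative parameters).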
+ const uint8* output_data =
+ AssertAndReturnTypedTensor<uint8>(output_tensor);
+ for (int j = 0; j < head.label_map_items.size(); ++j) {
+ score_pairs.emplace_back(j, output_tensor->params.scale *
+ (static_cast<int>(output_data[j]) -
+ output_tensor->params.zero_point));
+ }
+ } else {
+ const float* output_data =
+ AssertAndReturnTypedTensor<float>(output_tensor);
+ for (int j = 0; j < head.label_map_items.size(); ++j) {
+ score_pairs.emplace_back(j, output_data[j]);
+ }
+ }
+
+ // Optional score calibration.
+ if (score_calibrations_[i] != nullptr) {
+ for (auto& score_pair : score_pairs) {
+ const std::string& class_name =
+ head.label_map_items[score_pair.first].name;
+ score_pair.second = score_calibrations_[i]->ComputeCalibratedScore(
+ class_name, score_pair.second);
+ if (score_pair.second > kMaxCalibratedScore) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("calibrated score is too high: got %f, expected "
+ "%f as maximum.",
+ score_pair.second, kMaxCalibratedScore));
+ }
+ if (score_pair.second != kDefaultCalibratedScore &&
+ score_pair.second < kMinCalibratedScore) {
+ return CreateStatusWithPayload(
+ StatusCode::kInternal,
+ absl::StrFormat("calibrated score is too low: got %f, expected "
+ "%f as minimum.",
+ score_pair.second, kMinCalibratedScore));
+ }
+ }
+ }
+
+ int num_results =
+ options_->max_results() >= 0
+ ? std::min(static_cast<int>(head.label_map_items.size()),
+ options_->max_results())
+ : head.label_map_items.size();
+ float score_threshold = options_->has_score_threshold()
+ ? options_->score_threshold()
+ : head.score_threshold;
+
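+    // E.g. (illustrative): max_results = -1 makes every class eligible, while
+    // max_results = 3 keeps at most the top-3 classes, all subject to the
+    // score threshold selected above.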
+ if (class_name_set_.values.empty()) {
+ // Partially sort in descending order (higher score is better).
+ absl::c_partial_sort(
+ score_pairs, score_pairs.begin() + num_results,
+ [](const std::pair<int, float>& a, const std::pair<int, float>& b) {
+ return a.second > b.second;
+ });
+
+ for (int j = 0; j < num_results; ++j) {
+ float score = score_pairs[j].second;
+ if (score < score_threshold) {
+ break;
+ }
+ auto* cl = classifications->add_classes();
+ cl->set_index(score_pairs[j].first);
+ cl->set_score(score);
+ }
+ } else {
+ // Sort in descending order (higher score is better).
+ absl::c_sort(score_pairs, [](const std::pair<int, float>& a,
+ const std::pair<int, float>& b) {
+ return a.second > b.second;
+ });
+
+ for (int j = 0; j < head.label_map_items.size(); ++j) {
+ float score = score_pairs[j].second;
+ if (score < score_threshold ||
+ classifications->classes_size() >= num_results) {
+ break;
+ }
+
+ const int class_index = score_pairs[j].first;
+ const std::string& class_name = head.label_map_items[class_index].name;
+
+ bool class_name_found = class_name_set_.values.contains(class_name);
+
+ if ((!class_name_found && class_name_set_.is_whitelist) ||
+ (class_name_found && !class_name_set_.is_whitelist)) {
+ continue;
+ }
+
+ auto* cl = classifications->add_classes();
+ cl->set_index(class_index);
+ cl->set_score(score);
+ }
+ }
+ }
+
+ RETURN_IF_ERROR(FillResultsFromLabelMaps(&result));
+
+ return result;
+}
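+
+// Illustrative shape of the returned result (field names follow the setters
+// used above; indices, scores and labels are made up):
+//
+//   classifications {
+//     head_index: 0
+//     classes { index: 283 score: 0.97 class_name: "tiger cat" }
+//     classes { index: 282 score: 0.02 class_name: "tabby" }
+//   }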
+
+absl::Status ImageClassifier::FillResultsFromLabelMaps(
+ ClassificationResult* result) {
+ for (int i = 0; i < result->classifications_size(); ++i) {
+ Classifications* classifications = result->mutable_classifications(i);
+ int head_index = classifications->head_index();
+ if (head_index < 0 || head_index >= classification_heads_.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Invalid head index (%d) with respect to total "
+ "number of classification heads (%d).",
+ head_index, classification_heads_.size()),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+ const std::vector<LabelMapItem>& label_map_items =
+ classification_heads_[head_index].label_map_items;
+ for (int j = 0; j < classifications->classes_size(); ++j) {
+ Class* current_class = classifications->mutable_classes(j);
+ int current_class_index = current_class->index();
+ if (current_class_index < 0 ||
+ current_class_index >= label_map_items.size()) {
+ return CreateStatusWithPayload(
+ StatusCode::kInvalidArgument,
+ absl::StrFormat("Invalid class index (%d) with respect to label "
+ "map size (%d) for head #%d.",
+ current_class_index, label_map_items.size(),
+ head_index),
+ TfLiteSupportStatus::kMetadataInconsistencyError);
+ }
+ const std::string& name = label_map_items[current_class_index].name;
+ if (!name.empty()) {
+ current_class->set_class_name(name);
+ }
+ const std::string& display_name =
+ label_map_items[current_class_index].display_name;
+ if (!display_name.empty()) {
+ current_class->set_display_name(display_name);
+ }
+ }
+ }
+ return absl::OkStatus();
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite