aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h')
-rw-r--r--tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h93
1 files changed, 93 insertions, 0 deletions
diff --git a/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h
new file mode 100644
index 00000000..536eed4d
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h
@@ -0,0 +1,93 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
+
+#include <array>
+
+#include "absl/types/optional.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow_lite_support/cc/port/statusor.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+
+// Parameters used for input image normalization when the input tensor has
+// kTfLiteFloat32 type.
+//
+// Exactly 1 or 3 values are expected for `mean_values` and `std_values`. If
+// only 1 value is specified, it is used for all channels. E.g. for an RGB
+// image, the normalization is done as follows:
+//
+// (R - mean_values[0]) / std_values[0]
+// (G - mean_values[1]) / std_values[1]
+// (B - mean_values[2]) / std_values[2]
+//
+// `num_values` keeps track of how many values have been provided, which should
+// be 1 or 3 (see above). In particular, single-channel grayscale images expect
+// only 1 value.
+struct NormalizationOptions {
+ std::array<float, 3> mean_values;  // Subtracted from each channel.
+ std::array<float, 3> std_values;  // Each channel is divided by these.
+ int num_values;  // Number of values provided: 1 or 3.
+};
+
+// Parameters related to the expected tensor specifications when the tensor
+// represents an image.
+//
+// E.g. input tensor specifications expected by the model at Invoke() time. In
+// such a case, and before running inference with the TF Lite interpreter, the
+// caller must use these values and perform image preprocessing and/or
+// normalization so as to fill the actual input tensor appropriately.
+struct ImageTensorSpecs {
+ // Expected image dimensions, e.g. image_width=224, image_height=224.
+ int image_width;
+ int image_height;
+ // Expected color space, e.g. color_space=RGB.
+ tflite::ColorSpaceType color_space;
+ // Expected input tensor type, e.g. if tensor_type=kTfLiteFloat32 the caller
+ // should usually perform some normalization to convert the uint8 pixels into
+ // floats (see NormalizationOptions in TF Lite Metadata for more details).
+ TfLiteType tensor_type;
+ // Optional normalization parameters read from TF Lite Metadata. These are
+ // mandatory when tensor_type=kTfLiteFloat32 in order to convert the input
+ // image data into the expected range of floating point values; an error is
+ // returned otherwise (see sanity checks below). They should be ignored for
+ // other tensor input types, e.g. kTfLiteUInt8.
+ absl::optional<NormalizationOptions> normalization_options;
+};
+
+// Performs sanity checks on the expected input tensor, including consistency
+// checks against model metadata, if any. For now, a single RGB input with BHWD
+// layout, where B = 1 and D = 3, is expected. Returns the corresponding input
+// specifications if the checks pass, or an error otherwise (e.g. too many
+// input tensors).
+// Note: both interpreter and metadata extractor *must* be successfully
+// initialized before calling this function by means of (respectively):
+// - `tflite::InterpreterBuilder`,
+// - `tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer`.
+tflite::support::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
+ const tflite::task::core::TfLiteEngine::Interpreter& interpreter,
+ const tflite::metadata::ModelMetadataExtractor& metadata_extractor);
+
+} // namespace vision
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_