aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc')
-rw-r--r--tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc428
1 files changed, 428 insertions, 0 deletions
diff --git a/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
new file mode 100644
index 00000000..fa9b05f5
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.cc
@@ -0,0 +1,428 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+
+namespace tflite {
+namespace task {
+namespace vision {
+namespace {
+
+using ::tflite::support::StatusOr;
+
+constexpr int kRgbaChannels = 4;
+constexpr int kRgbChannels = 3;
+constexpr int kGrayChannel = 1;
+
+// Wraps a contiguous raw NV12 buffer in a two-plane FrameBuffer.
+//
+// The first plane starts at `input` with a row stride of `dimension.width` and
+// a pixel stride of `kGrayChannel`. The second plane starts `dimension.Size()`
+// bytes after `input`, with the same row stride and a pixel stride of 2.
+std::unique_ptr<FrameBuffer> CreateFromNV12RawBuffer(
+    const uint8* input, FrameBuffer::Dimension dimension,
+    FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+  const FrameBuffer::Plane y_plane = {
+      /*buffer=*/input,
+      /*stride=*/{dimension.width, kGrayChannel}};
+  const FrameBuffer::Plane uv_plane = {
+      /*buffer=*/input + dimension.Size(),
+      /*stride=*/{dimension.width, 2}};
+  const std::vector<FrameBuffer::Plane> planes_nv12 = {y_plane, uv_plane};
+  return FrameBuffer::Create(planes_nv12, dimension, FrameBuffer::Format::kNV12,
+                             orientation, timestamp);
+}
+
+// Wraps a raw NV21 buffer in a FrameBuffer.
+//
+// Only a single plane is registered here: it starts at `input`, with a row
+// stride of `dimension.width` and a pixel stride of `kGrayChannel`.
+std::unique_ptr<FrameBuffer> CreateFromNV21RawBuffer(
+    const uint8* input, FrameBuffer::Dimension dimension,
+    FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+  const FrameBuffer::Plane nv21_plane = {
+      /*buffer=*/input,
+      /*stride=*/{dimension.width, kGrayChannel}};
+  return FrameBuffer::Create({nv21_plane}, dimension,
+                             FrameBuffer::Format::kNV21, orientation,
+                             timestamp);
+}
+
+// Returns true when both buffers report equal dimensions.
+bool AreBufferDimsEqual(const FrameBuffer& buffer1,
+                        const FrameBuffer& buffer2) {
+  const auto& first_dimension = buffer1.dimension();
+  const auto& second_dimension = buffer2.dimension();
+  return first_dimension == second_dimension;
+}
+
+// Returns true when the formats of the two buffers are compatible:
+//   * RGB and RGBA are compatible with each other;
+//   * all YUV family formats (NV12, NV21, YV12, YV21) are compatible with each
+//     other;
+//   * any other format (e.g. grayscale) is only compatible with itself.
+bool AreBufferFormatsCompatible(const FrameBuffer& buffer1,
+                                const FrameBuffer& buffer2) {
+  const auto is_rgb_family = [](FrameBuffer::Format format) {
+    return format == FrameBuffer::Format::kRGBA ||
+           format == FrameBuffer::Format::kRGB;
+  };
+  const auto is_yuv_family = [](FrameBuffer::Format format) {
+    return format == FrameBuffer::Format::kNV12 ||
+           format == FrameBuffer::Format::kNV21 ||
+           format == FrameBuffer::Format::kYV12 ||
+           format == FrameBuffer::Format::kYV21;
+  };
+
+  const auto first_format = buffer1.format();
+  const auto second_format = buffer2.format();
+  if (is_rgb_family(first_format)) {
+    return is_rgb_family(second_format);
+  }
+  if (is_yuv_family(first_format)) {
+    return is_yuv_family(second_format);
+  }
+  // Grayscale and anything unrecognized must match exactly.
+  return first_format == second_format;
+}
+
+} // namespace
+
+// Miscellaneous Methods
+// -----------------------------------------------------------------
+// Returns the number of bytes needed to hold a frame of `dimension` pixels in
+// the given `format`, or 0 when `format` is not one of the handled formats.
+//
+// For the YUV family (NV12, NV21, YV12, YV21) the result is the size of the
+// full-resolution Y plane plus two chroma planes subsampled by 2 along each
+// axis. Odd widths and heights are rounded up, which is the same rounding
+// GetUvPlaneDimension() applies.
+int GetFrameBufferByteSize(FrameBuffer::Dimension dimension,
+                           FrameBuffer::Format format) {
+  switch (format) {
+    case FrameBuffer::Format::kNV12:
+    case FrameBuffer::Format::kNV21:
+    case FrameBuffer::Format::kYV12:
+    case FrameBuffer::Format::kYV21: {
+      // Integer ceiling division per axis. Floating point math must not be
+      // used here: (w + 1) / 2.0 is not rounded, so a 2x2 frame would yield
+      // 1.5 * 1.5 * 2 = 4.5 chroma bytes instead of 2.
+      const int uv_width = (dimension.width + 1) / 2;
+      const int uv_height = (dimension.height + 1) / 2;
+      return /*y plane*/ dimension.Size() +
+             /*uv plane*/ uv_width * uv_height * 2;
+    }
+    case FrameBuffer::Format::kRGB:
+      return dimension.Size() * 3;
+    case FrameBuffer::Format::kRGBA:
+      return dimension.Size() * 4;
+    case FrameBuffer::Format::kGRAY:
+      return dimension.Size();
+    default:
+      return 0;
+  }
+}
+
+// Returns the per-pixel byte count constant associated with `format`.
+//
+// Only kGRAY, kRGB and kRGBA are handled; any other format produces an
+// InvalidArgument error.
+StatusOr<int> GetPixelStrides(FrameBuffer::Format format) {
+  if (format == FrameBuffer::Format::kGRAY) {
+    return kGrayPixelBytes;
+  }
+  if (format == FrameBuffer::Format::kRGB) {
+    return kRgbPixelBytes;
+  }
+  if (format == FrameBuffer::Format::kRGBA) {
+    return kRgbaPixelBytes;
+  }
+  return absl::InvalidArgumentError(absl::StrFormat(
+      "GetPixelStrides does not support format: %i.", format));
+}
+
+// Returns the pointer at which the chroma data of an NV12 or NV21 buffer
+// starts: the `u_buffer` for NV12 and the `v_buffer` for NV21, as reported by
+// FrameBuffer::GetYuvDataFromFrameBuffer().
+//
+// Returns an InvalidArgument error for any other format, and propagates any
+// error raised while extracting the YUV data.
+StatusOr<const uint8*> GetUvRawBuffer(const FrameBuffer& buffer) {
+  const bool is_nv12 = buffer.format() == FrameBuffer::Format::kNV12;
+  const bool is_nv21 = buffer.format() == FrameBuffer::Format::kNV21;
+  if (!is_nv12 && !is_nv21) {
+    return absl::InvalidArgumentError(
+        "Only support getting biplanar UV buffer from NV12/NV21 frame buffer.");
+  }
+
+  ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data,
+                   FrameBuffer::GetYuvDataFromFrameBuffer(buffer));
+  if (is_nv12) {
+    return yuv_data.u_buffer;
+  }
+  return yuv_data.v_buffer;
+}
+
+// Computes the dimension of a chroma plane for a YUV frame of the given
+// `dimension`: each axis is halved, rounding up.
+//
+// Returns an InvalidArgument error when either side of `dimension` is not
+// strictly positive, or when `format` is not NV12, NV21, YV12 or YV21. The
+// dimension check runs first.
+StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
+    FrameBuffer::Dimension dimension, FrameBuffer::Format format) {
+  const bool has_positive_sides = dimension.width > 0 && dimension.height > 0;
+  if (!has_positive_sides) {
+    return absl::InvalidArgumentError(
+        absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width,
+                        dimension.height));
+  }
+
+  const bool is_yuv_like = format == FrameBuffer::Format::kNV12 ||
+                           format == FrameBuffer::Format::kNV21 ||
+                           format == FrameBuffer::Format::kYV12 ||
+                           format == FrameBuffer::Format::kYV21;
+  if (!is_yuv_like) {
+    return absl::InvalidArgumentError(
+        absl::StrFormat("Input format is not YUV-like: %i.", format));
+  }
+
+  return FrameBuffer::Dimension{(dimension.width + 1) / 2,
+                                (dimension.height + 1) / 2};
+}
+
+// Returns the dimension of the crop region spanning [x0, x1] horizontally and
+// [y0, y1] vertically. Both end coordinates are inclusive, hence the `+ 1`.
+// Note the parameter order: x0, x1, y0, y1.
+FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1) {
+  const int crop_width = x1 - x0 + 1;
+  const int crop_height = y1 - y0 + 1;
+  return {crop_width, crop_height};
+}
+
+// Validation Methods
+// -----------------------------------------------------------------
+
+// Checks the plane metadata of `buffer`.
+//
+// Returns an InvalidArgument error when the buffer has no plane, or when any
+// plane has a row stride or a pixel stride equal to 0. Returns OK otherwise.
+absl::Status ValidateBufferPlaneMetadata(const FrameBuffer& buffer) {
+  const auto plane_count = buffer.plane_count();
+  if (plane_count < 1) {
+    return absl::InvalidArgumentError(
+        "There must be at least 1 plane specified.");
+  }
+
+  for (int index = 0; index < plane_count; ++index) {
+    const auto stride = buffer.plane(index).stride;
+    const bool has_zero_stride =
+        stride.row_stride_bytes == 0 || stride.pixel_stride_bytes == 0;
+    if (has_zero_stride) {
+      return absl::InvalidArgumentError("Invalid stride information.");
+    }
+  }
+
+  return absl::OkStatus();
+}
+
+// Checks that the format of `buffer` is supported and consistent with its
+// plane count.
+//
+//   * kGRAY, kRGB, kRGBA: OK only with exactly one plane, InvalidArgument
+//     otherwise.
+//   * kNV21, kNV12, kYV21, kYV12: always OK (plane count is not checked).
+//   * Anything else: Internal error.
+absl::Status ValidateBufferFormat(const FrameBuffer& buffer) {
+  const auto format = buffer.format();
+
+  const bool is_single_plane_format = format == FrameBuffer::Format::kGRAY ||
+                                      format == FrameBuffer::Format::kRGB ||
+                                      format == FrameBuffer::Format::kRGBA;
+  if (is_single_plane_format) {
+    if (buffer.plane_count() != 1) {
+      return absl::InvalidArgumentError(
+          "Plane count must be 1 for grayscale and RGB[a] buffers.");
+    }
+    return absl::OkStatus();
+  }
+
+  const bool is_yuv_format = format == FrameBuffer::Format::kNV21 ||
+                             format == FrameBuffer::Format::kNV12 ||
+                             format == FrameBuffer::Format::kYV21 ||
+                             format == FrameBuffer::Format::kYV12;
+  if (is_yuv_format) {
+    return absl::OkStatus();
+  }
+
+  return absl::InternalError(
+      absl::StrFormat("Unsupported buffer format: %i.", format));
+}
+
+// Runs ValidateBufferFormat() on `buffer1` and then on `buffer2`, returning
+// the first error encountered, or OK when both buffers pass.
+absl::Status ValidateBufferFormats(const FrameBuffer& buffer1,
+                                   const FrameBuffer& buffer2) {
+  const absl::Status first_status = ValidateBufferFormat(buffer1);
+  if (!first_status.ok()) {
+    return first_status;
+  }
+  return ValidateBufferFormat(buffer2);
+}
+
+// Validates the input/output buffer pair of a resize operation.
+//
+// The output format must be identical to the input format, except for an RGBA
+// input, which accepts either an RGBA or an RGB output. A mismatch yields an
+// InvalidArgument error; an unrecognized input format yields an Internal
+// error. When the formats are acceptable, the result of
+// ValidateBufferFormats() on both buffers is returned.
+absl::Status ValidateResizeBufferInputs(const FrameBuffer& buffer,
+                                        const FrameBuffer& output_buffer) {
+  const auto input_format = buffer.format();
+  const auto output_format = output_buffer.format();
+
+  bool formats_match = false;
+  switch (input_format) {
+    case FrameBuffer::Format::kRGBA:
+      formats_match = output_format == FrameBuffer::Format::kRGBA ||
+                      output_format == FrameBuffer::Format::kRGB;
+      break;
+    case FrameBuffer::Format::kGRAY:
+    case FrameBuffer::Format::kRGB:
+    case FrameBuffer::Format::kNV12:
+    case FrameBuffer::Format::kNV21:
+    case FrameBuffer::Format::kYV12:
+    case FrameBuffer::Format::kYV21:
+      formats_match = input_format == output_format;
+      break;
+    default:
+      return absl::InternalError(
+          absl::StrFormat("Unsupported buffer format: %i.", input_format));
+  }
+
+  if (formats_match) {
+    return ValidateBufferFormats(buffer, output_buffer);
+  }
+  return absl::InvalidArgumentError(
+      "Input and output buffer formats must match.");
+}
+
+// Validates the inputs of a rotate operation.
+//
+// Checks, in order:
+//   1. the input and output formats are compatible;
+//   2. `angle_deg` is a multiple of 90 strictly between 0 and 360;
+//   3. the output dimensions are the input dimensions with width and height
+//      swapped when the angle is an odd multiple of 90, and are equal to the
+//      input dimensions otherwise.
+// Each failed check yields an InvalidArgument error.
+absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer,
+                                        const FrameBuffer& output_buffer,
+                                        int angle_deg) {
+  if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
+    return absl::InvalidArgumentError(
+        "Input and output buffer formats must match.");
+  }
+
+  const bool is_angle_valid =
+      angle_deg > 0 && angle_deg < 360 && angle_deg % 90 == 0;
+  if (!is_angle_valid) {
+    return absl::InvalidArgumentError(
+        "Rotation angle must be between 0 and 360, in multiples of 90 "
+        "degrees.");
+  }
+
+  const auto& input_dimension = buffer.dimension();
+  const auto& output_dimension = output_buffer.dimension();
+  const bool swaps_sides = (angle_deg / 90) % 2 == 1;
+  bool has_expected_dimension = false;
+  if (swaps_sides) {
+    has_expected_dimension =
+        (input_dimension.width == output_dimension.height) &&
+        (input_dimension.height == output_dimension.width);
+  } else {
+    has_expected_dimension = input_dimension == output_dimension;
+  }
+
+  if (!has_expected_dimension) {
+    return absl::InvalidArgumentError(
+        "Output buffer has invalid dimensions for rotation.");
+  }
+  return absl::OkStatus();
+}
+
+// Validates the inputs of a crop operation.
+//
+// The input and output formats must be compatible, and the crop rectangle
+// must satisfy 0 <= x0 <= x1 < width and 0 <= y0 <= y1 < height, where width
+// and height are those of the input `buffer`. Any violation yields an
+// InvalidArgument error. The dimensions of `output_buffer` are not inspected.
+absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer,
+                                      const FrameBuffer& output_buffer, int x0,
+                                      int y0, int x1, int y1) {
+  if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
+    return absl::InvalidArgumentError(
+        "Input and output buffer formats must match.");
+  }
+
+  const bool exceeds_buffer =
+      (x1 >= buffer.dimension().width) || (y1 >= buffer.dimension().height);
+  const bool has_bad_corners = (x0 < 0) || (y0 < 0) || (x1 < x0) || (y1 < y0);
+  if (exceeds_buffer || has_bad_corners) {
+    return absl::InvalidArgumentError("Invalid crop coordinates.");
+  }
+  return absl::OkStatus();
+}
+
+// Validates the inputs of a flip operation: the two buffers must have
+// compatible formats and identical dimensions. Each failed check yields an
+// InvalidArgument error; the format check runs first.
+absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer,
+                                      const FrameBuffer& output_buffer) {
+  if (!AreBufferFormatsCompatible(buffer, output_buffer)) {
+    return absl::InvalidArgumentError(
+        "Input and output buffer formats must match.");
+  }
+  if (!AreBufferDimsEqual(buffer, output_buffer)) {
+    return absl::InvalidArgumentError(
+        "Input and output buffers must have the same dimensions.");
+  }
+  return absl::OkStatus();
+}
+
+// Checks whether a conversion from `from_format` to `to_format` is allowed.
+//
+// Rejected with an InvalidArgument error:
+//   * identical source and destination formats;
+//   * any conversion starting from grayscale;
+//   * RGB to RGBA.
+// Every other conversion starting from RGB, RGBA, NV12, NV21, YV12 or YV21 is
+// accepted. An unrecognized `from_format` yields an Internal error.
+absl::Status ValidateConvertFormats(FrameBuffer::Format from_format,
+                                    FrameBuffer::Format to_format) {
+  if (from_format == to_format) {
+    return absl::InvalidArgumentError("Formats must be different.");
+  }
+
+  if (from_format == FrameBuffer::Format::kGRAY) {
+    return absl::InvalidArgumentError(
+        "Grayscale format does not convert to other formats.");
+  }
+
+  if (from_format == FrameBuffer::Format::kRGB) {
+    if (to_format != FrameBuffer::Format::kRGBA) {
+      return absl::OkStatus();
+    }
+    return absl::InvalidArgumentError("RGB format does not convert to RGBA");
+  }
+
+  const bool is_convertible_source =
+      from_format == FrameBuffer::Format::kRGBA ||
+      from_format == FrameBuffer::Format::kNV12 ||
+      from_format == FrameBuffer::Format::kNV21 ||
+      from_format == FrameBuffer::Format::kYV12 ||
+      from_format == FrameBuffer::Format::kYV21;
+  if (is_convertible_source) {
+    return absl::OkStatus();
+  }
+
+  return absl::InternalError(
+      absl::StrFormat("Unsupported buffer format: %i.", from_format));
+}
+
+// Creation Methods
+// -----------------------------------------------------------------
+
+// Wraps a raw RGBA buffer in a single-plane FrameBuffer. The plane starts at
+// `input`, uses a pixel stride of `kRgbaChannels` and a row stride of
+// `dimension.width * kRgbaChannels`.
+std::unique_ptr<FrameBuffer> CreateFromRgbaRawBuffer(
+    const uint8* input, FrameBuffer::Dimension dimension,
+    FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+  const int row_stride_bytes = dimension.width * kRgbaChannels;
+  const FrameBuffer::Plane rgba_plane = {
+      /*buffer=*/input,
+      /*stride=*/{row_stride_bytes, kRgbaChannels}};
+  return FrameBuffer::Create({rgba_plane}, dimension,
+                             FrameBuffer::Format::kRGBA, orientation,
+                             timestamp);
+}
+
+// Wraps a raw RGB buffer in a single-plane FrameBuffer. The plane starts at
+// `input`, uses a pixel stride of `kRgbChannels` and a row stride of
+// `dimension.width * kRgbChannels`.
+std::unique_ptr<FrameBuffer> CreateFromRgbRawBuffer(
+    const uint8* input, FrameBuffer::Dimension dimension,
+    FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+  const int row_stride_bytes = dimension.width * kRgbChannels;
+  const FrameBuffer::Plane rgb_plane = {
+      /*buffer=*/input,
+      /*stride=*/{row_stride_bytes, kRgbChannels}};
+  return FrameBuffer::Create({rgb_plane}, dimension, FrameBuffer::Format::kRGB,
+                             orientation, timestamp);
+}
+
+// Wraps a raw grayscale buffer in a single-plane FrameBuffer. The plane starts
+// at `input`, uses a pixel stride of `kGrayChannel` and a row stride of
+// `dimension.width`.
+std::unique_ptr<FrameBuffer> CreateFromGrayRawBuffer(
+    const uint8* input, FrameBuffer::Dimension dimension,
+    FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+  const FrameBuffer::Plane gray_plane = {
+      /*buffer=*/input,
+      /*stride=*/{dimension.width, kGrayChannel}};
+  return FrameBuffer::Create({gray_plane}, dimension,
+                             FrameBuffer::Format::kGRAY, orientation,
+                             timestamp);
+}
+
+// Builds a three-plane FrameBuffer from separate Y, U and V plane pointers.
+//
+// The Y plane always comes first, with a pixel stride of 1 and a row stride of
+// `row_stride_y`. Both chroma planes share `row_stride_uv` and
+// `pixel_stride_uv`; their order depends on `format`:
+//   * kNV21, kYV12: V plane, then U plane;
+//   * kNV12, kYV21: U plane, then V plane.
+// Any other format yields an InvalidArgument error.
+StatusOr<std::unique_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
+    const uint8* y_plane, const uint8* u_plane, const uint8* v_plane,
+    FrameBuffer::Format format, FrameBuffer::Dimension dimension,
+    int row_stride_y, int row_stride_uv, int pixel_stride_uv,
+    FrameBuffer::Orientation orientation, const absl::Time timestamp) {
+  const bool is_v_first = format == FrameBuffer::Format::kNV21 ||
+                          format == FrameBuffer::Format::kYV12;
+  const bool is_u_first = format == FrameBuffer::Format::kNV12 ||
+                          format == FrameBuffer::Format::kYV21;
+  if (!is_v_first && !is_u_first) {
+    return absl::InvalidArgumentError(
+        absl::StrFormat("Input format is not YUV-like: %i.", format));
+  }
+
+  const int pixel_stride_y = 1;
+  const FrameBuffer::Plane luma = {
+      y_plane, /*stride=*/{row_stride_y, pixel_stride_y}};
+  const FrameBuffer::Plane chroma_u = {
+      u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}};
+  const FrameBuffer::Plane chroma_v = {
+      v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}};
+
+  std::vector<FrameBuffer::Plane> planes;
+  if (is_v_first) {
+    planes = {luma, chroma_v, chroma_u};
+  } else {
+    planes = {luma, chroma_u, chroma_v};
+  }
+  return FrameBuffer::Create(planes, dimension, format, orientation, timestamp);
+}
+
+// Creates a FrameBuffer over a single contiguous raw `buffer`, dispatching on
+// `target_format` to the matching format-specific factory.
+//
+// For kYV12 and kYV21 the buffer is split into three consecutive regions: the
+// Y plane (`dimension.Size()` bytes) followed by two chroma planes of
+// `uv_dimension.Size()` bytes each, where `uv_dimension` comes from
+// GetUvPlaneDimension(). For kYV12 the first chroma region is passed as the V
+// plane and the second as the U plane; kYV21 is the other way round.
+//
+// Returns an Internal error for an unhandled `target_format`, and propagates
+// errors from GetUvPlaneDimension() and CreateFromYuvRawBuffer().
+StatusOr<std::unique_ptr<FrameBuffer>> CreateFromRawBuffer(
+    const uint8* buffer, FrameBuffer::Dimension dimension,
+    const FrameBuffer::Format target_format,
+    FrameBuffer::Orientation orientation, absl::Time timestamp) {
+  switch (target_format) {
+    case FrameBuffer::Format::kGRAY:
+      return CreateFromGrayRawBuffer(buffer, dimension, orientation, timestamp);
+    case FrameBuffer::Format::kRGB:
+      return CreateFromRgbRawBuffer(buffer, dimension, orientation, timestamp);
+    case FrameBuffer::Format::kRGBA:
+      return CreateFromRgbaRawBuffer(buffer, dimension, orientation, timestamp);
+    case FrameBuffer::Format::kNV12:
+      return CreateFromNV12RawBuffer(buffer, dimension, orientation, timestamp);
+    case FrameBuffer::Format::kNV21:
+      return CreateFromNV21RawBuffer(buffer, dimension, orientation, timestamp);
+    case FrameBuffer::Format::kYV12:
+    case FrameBuffer::Format::kYV21: {
+      ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension,
+                       GetUvPlaneDimension(dimension, target_format));
+      // The two chroma regions sit back to back right after the Y plane.
+      const uint8* first_chroma = buffer + dimension.Size();
+      const uint8* second_chroma = first_chroma + uv_dimension.Size();
+      const bool is_v_first = target_format == FrameBuffer::Format::kYV12;
+      return CreateFromYuvRawBuffer(
+          /*y_plane=*/buffer,
+          /*u_plane=*/is_v_first ? second_chroma : first_chroma,
+          /*v_plane=*/is_v_first ? first_chroma : second_chroma,
+          target_format, dimension,
+          /*row_stride_y=*/dimension.width,
+          /*row_stride_uv=*/uv_dimension.width,
+          /*pixel_stride_uv=*/1, orientation, timestamp);
+    }
+    default:
+      return absl::InternalError(
+          absl::StrFormat("Unsupported buffer format: %i.", target_format));
+  }
+}
+
+} // namespace vision
+} // namespace task
+} // namespace tflite