Diffstat (limited to 'tensorflow_lite_support/cc/task/core/tflite_engine.h')
-rw-r--r--  tensorflow_lite_support/cc/task/core/tflite_engine.h  245
1 file changed, 245 insertions, 0 deletions
diff --git a/tensorflow_lite_support/cc/task/core/tflite_engine.h b/tensorflow_lite_support/cc/task/core/tflite_engine.h
new file mode 100644
index 00000000..30f239da
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/core/tflite_engine.h
@@ -0,0 +1,245 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
+
+#include <sys/mman.h>
+
+#include <cstdarg>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
+#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
+#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
+#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
+
+// If compiled with -DTFLITE_USE_C_API, this file will use the TF Lite C API
+// rather than the TF Lite C++ API.
+// TODO(b/168025296): eliminate the '#if TFLITE_USE_C_API' directives here and
+// elsewhere and instead use the C API unconditionally, once we have a suitable
+// replacement for the features of tflite::support::TfLiteInterpreterWrapper.
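+//
+// For example (a sketch; the exact build invocation is illustrative only):
+//
+//   bazel build --copt=-DTFLITE_USE_C_API=1 <your_target>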
+#if TFLITE_USE_C_API
+#include "tensorflow/lite/c/c_api.h"
+#include "tensorflow/lite/core/api/verifier.h"
+#include "tensorflow/lite/tools/verifier.h"
+#else
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/model.h"
+#endif
+
+namespace tflite {
+namespace task {
+namespace core {
+
+// TfLiteEngine encapsulates logic for TFLite model initialization, inference
+// and error reporting.
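+//
+// Example usage (a minimal sketch; the model path and error handling are
+// illustrative only):
+//
+//   tflite::task::core::TfLiteEngine engine;
+//   absl::Status status = engine.BuildModelFromFile("/path/to/model.tflite");
+//   if (status.ok()) status = engine.InitInterpreter();
+//   if (status.ok()) {
+//     TfLiteTensor* input = TfLiteEngine::GetInput(engine.interpreter(), 0);
+//     // ... fill `input`, invoke inference, then read outputs ...
+//   }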
+class TfLiteEngine {
+ public:
+ // Types.
+ using InterpreterWrapper = tflite::support::TfLiteInterpreterWrapper;
+#if TFLITE_USE_C_API
+ using Model = struct TfLiteModel;
+ using Interpreter = struct TfLiteInterpreter;
+ using ModelDeleter = void (*)(Model*);
+ using InterpreterDeleter = InterpreterWrapper::InterpreterDeleter;
+#else
+ using Model = tflite::FlatBufferModel;
+ using Interpreter = tflite::Interpreter;
+ using ModelDeleter = std::default_delete<Model>;
+ using InterpreterDeleter = std::default_delete<Interpreter>;
+#endif
+
+ // Constructors.
+ explicit TfLiteEngine(
+ std::unique_ptr<tflite::OpResolver> resolver =
+ absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
+  // TfLiteEngine is neither copyable nor movable.
+ TfLiteEngine(const TfLiteEngine&) = delete;
+ TfLiteEngine& operator=(const TfLiteEngine&) = delete;
+
+ // Accessors.
+ static int32_t InputCount(const Interpreter* interpreter) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetInputTensorCount(interpreter);
+#else
+ return interpreter->inputs().size();
+#endif
+ }
+ static int32_t OutputCount(const Interpreter* interpreter) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetOutputTensorCount(interpreter);
+#else
+ return interpreter->outputs().size();
+#endif
+ }
+ static TfLiteTensor* GetInput(Interpreter* interpreter, int index) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetInputTensor(interpreter, index);
+#else
+ return interpreter->tensor(interpreter->inputs()[index]);
+#endif
+ }
+ // Same as above, but const.
+ static const TfLiteTensor* GetInput(const Interpreter* interpreter,
+ int index) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetInputTensor(interpreter, index);
+#else
+ return interpreter->tensor(interpreter->inputs()[index]);
+#endif
+ }
+ static TfLiteTensor* GetOutput(Interpreter* interpreter, int index) {
+#if TFLITE_USE_C_API
+ // We need a const_cast here, because the TF Lite C API only has a non-const
+ // version of GetOutputTensor (in part because C doesn't support overloading
+ // on const).
+ return const_cast<TfLiteTensor*>(
+ TfLiteInterpreterGetOutputTensor(interpreter, index));
+#else
+ return interpreter->tensor(interpreter->outputs()[index]);
+#endif
+ }
+ // Same as above, but const.
+ static const TfLiteTensor* GetOutput(const Interpreter* interpreter,
+ int index) {
+#if TFLITE_USE_C_API
+ return TfLiteInterpreterGetOutputTensor(interpreter, index);
+#else
+ return interpreter->tensor(interpreter->outputs()[index]);
+#endif
+ }
+
+ std::vector<TfLiteTensor*> GetInputs();
+ std::vector<const TfLiteTensor*> GetOutputs();
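+
+  // Example of reading a float output tensor (a sketch; `interpreter` is
+  // assumed to come from engine.interpreter(), already initialized and
+  // invoked, with float data at output 0):
+  //
+  //   const TfLiteTensor* out = TfLiteEngine::GetOutput(interpreter, 0);
+  //   const float* values = out->data.f;
+  //   int32_t num_outputs = TfLiteEngine::OutputCount(interpreter);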
+
+ const Model* model() const { return model_.get(); }
+ Interpreter* interpreter() { return interpreter_.get(); }
+ const Interpreter* interpreter() const { return interpreter_.get(); }
+ InterpreterWrapper* interpreter_wrapper() { return &interpreter_; }
+ const tflite::metadata::ModelMetadataExtractor* metadata_extractor() const {
+ return model_metadata_extractor_.get();
+ }
+
+ // Builds the TF Lite FlatBufferModel (model_) from the raw FlatBuffer data
+ // whose ownership remains with the caller, and which must outlive the current
+ // object. This performs extra verification on the input data using
+ // tflite::Verify.
+ absl::Status BuildModelFromFlatBuffer(const char* buffer_data,
+ size_t buffer_size);
+
+ // Builds the TF Lite model from a given file.
+ absl::Status BuildModelFromFile(const std::string& file_name);
+
+ // Builds the TF Lite model from a given file descriptor using mmap(2).
+ absl::Status BuildModelFromFileDescriptor(int file_descriptor);
+
+ // Builds the TFLite model from the provided ExternalFile proto, which must
+ // outlive the current object.
+ absl::Status BuildModelFromExternalFileProto(
+ const ExternalFile* external_file);
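+
+  // Example (a sketch; assumes `external_file` outlives the engine, and that
+  // ExternalFile provides a `file_name` field as defined in
+  // external_file.proto):
+  //
+  //   ExternalFile external_file;
+  //   external_file.set_file_name("/path/to/model.tflite");
+  //   absl::Status status =
+  //       engine.BuildModelFromExternalFileProto(&external_file);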
+
+  // Initializes the interpreter with the encapsulated model.
+  // Note: setting num_threads to -1 lets the TFLite runtime set the value.
+ absl::Status InitInterpreter(int num_threads = 1);
+
+ // Same as above, but allows specifying `compute_settings` for acceleration.
+ absl::Status InitInterpreter(
+ const tflite::proto::ComputeSettings& compute_settings,
+ int num_threads = 1);
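+
+  // Example (a sketch; `engine` holds an already-built model, and -1 defers
+  // the thread-count choice to the TFLite runtime):
+  //
+  //   absl::Status status = engine.InitInterpreter(/*num_threads=*/-1);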
+
+  // Cancels the ongoing `Invoke()` call, if any and if possible. This method
+  // can be called from a different thread than the one where `Invoke()` is
+  // running.
+ void Cancel() {
+#if TFLITE_USE_C_API
+ // NOP.
+#else
+ interpreter_.Cancel();
+#endif
+ }
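+
+  // Example (a sketch; assumes the C++ API build, where Cancel() is not a
+  // NOP, and that `engine` outlives both threads):
+  //
+  //   std::thread worker([&engine] { engine.interpreter()->Invoke(); });
+  //   engine.Cancel();  // Called from a thread other than the Invoke() one.
+  //   worker.join();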
+
+ protected:
+  // TF Lite's DefaultErrorReporter() outputs to stderr. This one captures the
+  // error into a string so that it can be used to complement absl::Status
+  // error messages.
+ struct ErrorReporter : public tflite::ErrorReporter {
+ // Last error message captured by this error reporter.
+ char error_message[256];
+ int Report(const char* format, va_list args) override;
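+
+    // A plausible Report() definition (a sketch; the real implementation
+    // lives in the corresponding .cc file):
+    //
+    //   int TfLiteEngine::ErrorReporter::Report(const char* format,
+    //                                           va_list args) {
+    //     return vsnprintf(error_message, sizeof(error_message), format,
+    //                      args);
+    //   }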
+ };
+ // Custom error reporter capturing low-level TF Lite error messages.
+ ErrorReporter error_reporter_;
+
+ private:
+ // Direct wrapper around tflite::TfLiteVerifier which checks the integrity of
+ // the FlatBuffer data provided as input.
+ class Verifier : public tflite::TfLiteVerifier {
+ public:
+ explicit Verifier(const tflite::OpResolver* op_resolver)
+ : op_resolver_(op_resolver) {}
+ bool Verify(const char* data, int length,
+ tflite::ErrorReporter* reporter) override;
+ // The OpResolver to be used to build the TF Lite interpreter.
+ const tflite::OpResolver* op_resolver_;
+ };
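+
+  // A plausible Verify() definition (a sketch; the real implementation lives
+  // in the .cc file, and the tflite::Verify() overload taking an OpResolver
+  // is assumed):
+  //
+  //   bool TfLiteEngine::Verifier::Verify(const char* data, int length,
+  //                                       tflite::ErrorReporter* reporter) {
+  //     return tflite::Verify(data, length, *op_resolver_, reporter);
+  //   }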
+
+  // Verifies that the supplied buffer contains a valid FlatBuffer model and
+  // uses only operators supported by the OpResolver that was passed to the
+  // TfLiteEngine constructor, then builds the model from the buffer and
+  // stores it in 'model_'.
+ void VerifyAndBuildModelFromBuffer(const char* buffer_data,
+ size_t buffer_size);
+
+  // Gets the buffer from the file handler; verifies and builds the model
+  // from the buffer; if successful, sets 'model_metadata_extractor_' to be
+  // a TF Lite Metadata extractor for the model; and returns an appropriate
+  // Status.
+ absl::Status InitializeFromModelFileHandler();
+
+ // TF Lite model and interpreter for actual inference.
+ std::unique_ptr<Model, ModelDeleter> model_;
+
+ // Interpreter wrapper built from the model.
+ InterpreterWrapper interpreter_;
+
+ // TFLite Metadata extractor built from the model.
+ std::unique_ptr<tflite::metadata::ModelMetadataExtractor>
+ model_metadata_extractor_;
+
+  // Mechanism used by TF Lite to map ops referenced in the FlatBuffer model
+  // to their actual implementations. Defaults to the TF Lite
+  // BuiltinOpResolver.
+ std::unique_ptr<tflite::OpResolver> resolver_;
+
+ // Extra verifier for FlatBuffer input data.
+ Verifier verifier_;
+
+ // ExternalFile and corresponding ExternalFileHandler for models loaded from
+ // disk or file descriptor.
+ ExternalFile external_file_;
+ std::unique_ptr<ExternalFileHandler> model_file_handler_;
+};
+
+} // namespace core
+} // namespace task
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_