diff options
author | Jared Junyoung Lim <jaredlim@google.com> | 2022-08-22 01:09:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2022-08-22 01:12:47 -0700 |
commit | 013172fadc31628eb3a4a8ca786c6259b57c7ba9 (patch) | |
tree | 9655c64a120bbf0903d54490aa00db47fab66635 | |
parent | df80c783418434e504841c70df6b4054644713f5 (diff) | |
download | tensorflow-013172fadc31628eb3a4a8ca786c6259b57c7ba9.tar.gz |
Expose the error in the benchmark_performance_options.
PiperOrigin-RevId: 469118331
3 files changed, 34 insertions, 18 deletions
diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index e677a8cdbbf..37b2d74adbc 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -184,6 +184,7 @@ cc_library( ":benchmark_model_lib", ":benchmark_params", ":benchmark_utils", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/tools:logging", "@com_google_absl//absl/memory", "//tensorflow/core/util:stats_calculator_portable", diff --git a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc index 923aa4dabe6..91d60350c2e 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc @@ -24,6 +24,7 @@ limitations under the License. #include <utility> #include "tensorflow/core/util/stats_calculator.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #if defined(__ANDROID__) #include "tensorflow/lite/delegates/gpu/delegate.h" @@ -164,14 +165,14 @@ std::vector<Flag> BenchmarkPerformanceOptions::GetFlags() { }; } -bool BenchmarkPerformanceOptions::ParseFlags(int* argc, char** argv) { +TfLiteStatus BenchmarkPerformanceOptions::ParseFlags(int* argc, char** argv) { auto flag_list = GetFlags(); const bool parse_result = Flags::Parse(argc, const_cast<const char**>(argv), flag_list); if (!parse_result) { std::string usage = Flags::Usage(argv[0], flag_list); TFLITE_LOG(ERROR) << usage; - return false; + return kTfLiteError; } // Parse the value of --perf_options_list to find performance options to be @@ -179,14 +180,14 @@ bool BenchmarkPerformanceOptions::ParseFlags(int* argc, char** argv) { return ParsePerfOptions(); } -bool BenchmarkPerformanceOptions::ParsePerfOptions() { +TfLiteStatus BenchmarkPerformanceOptions::ParsePerfOptions() { const auto& perf_options_list = params_.Get<std::string>("perf_options_list"); if (!util::SplitAndParse(perf_options_list, ',', &perf_options_)) { TFLITE_LOG(ERROR) << "Cannot parse --perf_options_list: '" << perf_options_list << "'. Please double-check its value."; perf_options_.clear(); - return false; + return kTfLiteError; } const auto valid_options = GetValidPerfOptions(); @@ -209,16 +210,16 @@ bool BenchmarkPerformanceOptions::ParsePerfOptions() { << perf_options_list << "'. Valid perf options are: [" << valid_options_str << "]"; perf_options_.clear(); - return false; + return kTfLiteError; } if (HasOption("none") && perf_options_.size() > 1) { TFLITE_LOG(ERROR) << "The 'none' option can not be used together with " "other perf options in --perf_options_list!"; perf_options_.clear(); - return false; + return kTfLiteError; } - return true; + return kTfLiteOk; } std::vector<std::string> BenchmarkPerformanceOptions::GetValidPerfOptions() @@ -348,7 +349,7 @@ void BenchmarkPerformanceOptions::CreatePerformanceOptions() { #endif } -void BenchmarkPerformanceOptions::Run() { +TfLiteStatus BenchmarkPerformanceOptions::Run() { CreatePerformanceOptions(); if (params_.Get<bool>("random_shuffle_benchmark_runs")) { @@ -379,26 +380,40 @@ void BenchmarkPerformanceOptions::Run() { single_option_run_->RemoveListeners(num_external_listeners); all_run_stats_->MarkBenchmarkStart(*single_option_run_params_); - single_option_run_->Run(); + if (TfLiteStatus status = single_option_run_->Run(); status != kTfLiteOk) { + TFLITE_LOG(ERROR) << "Error while running a single-option run: " + << status; + return status; + } } all_run_stats_->OutputStats(); + return kTfLiteOk; } -void BenchmarkPerformanceOptions::Run(int argc, char** argv) { +TfLiteStatus BenchmarkPerformanceOptions::Run(int argc, char** argv) { // Parse flags that are supported by this particular binary first. - if (!ParseFlags(&argc, argv)) return; + if (TfLiteStatus status = ParseFlags(&argc, argv); status != kTfLiteOk) { + TFLITE_LOG(ERROR) << "Error while parsing the flags for multi-option runs: " + << status; + return status; + } // Then parse flags for single-option runs to get information like parameters // of the input model etc. - if (single_option_run_->ParseFlags(&argc, argv) != kTfLiteOk) return; + if (TfLiteStatus status = single_option_run_->ParseFlags(&argc, argv); + status != kTfLiteOk) { + TFLITE_LOG(ERROR) + << "Error while parsing the flags for single-option runs: " << status; + return status; + } // Now, the remaining are unrecognized flags and we simply print them out. for (int i = 1; i < argc; ++i) { TFLITE_LOG(WARN) << "WARNING: unrecognized commandline flag: " << argv[i]; } - Run(); + return Run(); } } // namespace benchmark } // namespace tflite diff --git a/tensorflow/lite/tools/benchmark/benchmark_performance_options.h b/tensorflow/lite/tools/benchmark/benchmark_performance_options.h index 4aeaa8382a7..c13161812f0 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_performance_options.h +++ b/tensorflow/lite/tools/benchmark/benchmark_performance_options.h @@ -20,7 +20,7 @@ limitations under the License. #include <string> #include <vector> -#include "absl/memory/memory.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/tools/benchmark/benchmark_model.h" #include "tensorflow/lite/tools/benchmark/benchmark_params.h" @@ -85,8 +85,8 @@ class BenchmarkPerformanceOptions { virtual ~BenchmarkPerformanceOptions() {} // Just run the benchmark just w/ default parameter values. - void Run(); - void Run(int argc, char** argv); + TfLiteStatus Run(); + TfLiteStatus Run(int argc, char** argv); protected: static BenchmarkParams DefaultParams(); @@ -97,10 +97,10 @@ class BenchmarkPerformanceOptions { // Unparsable flags will remain in 'argv' in the original order and 'argc' // will be updated accordingly. - bool ParseFlags(int* argc, char** argv); + TfLiteStatus ParseFlags(int* argc, char** argv); virtual std::vector<Flag> GetFlags(); - bool ParsePerfOptions(); + TfLiteStatus ParsePerfOptions(); virtual std::vector<std::string> GetValidPerfOptions() const; bool HasOption(const std::string& option) const; |