aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJared Junyoung Lim <jaredlim@google.com>2022-08-22 01:09:10 -0700
committerTensorFlower Gardener <gardener@tensorflow.org>2022-08-22 01:12:47 -0700
commit013172fadc31628eb3a4a8ca786c6259b57c7ba9 (patch)
tree9655c64a120bbf0903d54490aa00db47fab66635
parentdf80c783418434e504841c70df6b4054644713f5 (diff)
downloadtensorflow-013172fadc31628eb3a4a8ca786c6259b57c7ba9.tar.gz
Expose the error in the benchmark_performance_options.
PiperOrigin-RevId: 469118331
-rw-r--r--tensorflow/lite/tools/benchmark/BUILD1
-rw-r--r--tensorflow/lite/tools/benchmark/benchmark_performance_options.cc41
-rw-r--r--tensorflow/lite/tools/benchmark/benchmark_performance_options.h10
3 files changed, 34 insertions, 18 deletions
diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD
index e677a8cdbbf..37b2d74adbc 100644
--- a/tensorflow/lite/tools/benchmark/BUILD
+++ b/tensorflow/lite/tools/benchmark/BUILD
@@ -184,6 +184,7 @@ cc_library(
":benchmark_model_lib",
":benchmark_params",
":benchmark_utils",
+ "//tensorflow/lite/c:c_api_types",
"//tensorflow/lite/tools:logging",
"@com_google_absl//absl/memory",
"//tensorflow/core/util:stats_calculator_portable",
diff --git a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc
index 923aa4dabe6..91d60350c2e 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_performance_options.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/util/stats_calculator.h"
+#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#if defined(__ANDROID__)
#include "tensorflow/lite/delegates/gpu/delegate.h"
@@ -164,14 +165,14 @@ std::vector<Flag> BenchmarkPerformanceOptions::GetFlags() {
};
}
-bool BenchmarkPerformanceOptions::ParseFlags(int* argc, char** argv) {
+TfLiteStatus BenchmarkPerformanceOptions::ParseFlags(int* argc, char** argv) {
auto flag_list = GetFlags();
const bool parse_result =
Flags::Parse(argc, const_cast<const char**>(argv), flag_list);
if (!parse_result) {
std::string usage = Flags::Usage(argv[0], flag_list);
TFLITE_LOG(ERROR) << usage;
- return false;
+ return kTfLiteError;
}
// Parse the value of --perf_options_list to find performance options to be
@@ -179,14 +180,14 @@ bool BenchmarkPerformanceOptions::ParseFlags(int* argc, char** argv) {
return ParsePerfOptions();
}
-bool BenchmarkPerformanceOptions::ParsePerfOptions() {
+TfLiteStatus BenchmarkPerformanceOptions::ParsePerfOptions() {
const auto& perf_options_list = params_.Get<std::string>("perf_options_list");
if (!util::SplitAndParse(perf_options_list, ',', &perf_options_)) {
TFLITE_LOG(ERROR) << "Cannot parse --perf_options_list: '"
<< perf_options_list
<< "'. Please double-check its value.";
perf_options_.clear();
- return false;
+ return kTfLiteError;
}
const auto valid_options = GetValidPerfOptions();
@@ -209,16 +210,16 @@ bool BenchmarkPerformanceOptions::ParsePerfOptions() {
<< perf_options_list << "'. Valid perf options are: ["
<< valid_options_str << "]";
perf_options_.clear();
- return false;
+ return kTfLiteError;
}
if (HasOption("none") && perf_options_.size() > 1) {
TFLITE_LOG(ERROR) << "The 'none' option can not be used together with "
"other perf options in --perf_options_list!";
perf_options_.clear();
- return false;
+ return kTfLiteError;
}
- return true;
+ return kTfLiteOk;
}
std::vector<std::string> BenchmarkPerformanceOptions::GetValidPerfOptions()
@@ -348,7 +349,7 @@ void BenchmarkPerformanceOptions::CreatePerformanceOptions() {
#endif
}
-void BenchmarkPerformanceOptions::Run() {
+TfLiteStatus BenchmarkPerformanceOptions::Run() {
CreatePerformanceOptions();
if (params_.Get<bool>("random_shuffle_benchmark_runs")) {
@@ -379,26 +380,40 @@ void BenchmarkPerformanceOptions::Run() {
single_option_run_->RemoveListeners(num_external_listeners);
all_run_stats_->MarkBenchmarkStart(*single_option_run_params_);
- single_option_run_->Run();
+ if (TfLiteStatus status = single_option_run_->Run(); status != kTfLiteOk) {
+ TFLITE_LOG(ERROR) << "Error while running a single-option run: "
+ << status;
+ return status;
+ }
}
all_run_stats_->OutputStats();
+ return kTfLiteOk;
}
-void BenchmarkPerformanceOptions::Run(int argc, char** argv) {
+TfLiteStatus BenchmarkPerformanceOptions::Run(int argc, char** argv) {
// Parse flags that are supported by this particular binary first.
- if (!ParseFlags(&argc, argv)) return;
+ if (TfLiteStatus status = ParseFlags(&argc, argv); status != kTfLiteOk) {
+ TFLITE_LOG(ERROR) << "Error while parsing the flags for multi-option runs: "
+ << status;
+ return status;
+ }
// Then parse flags for single-option runs to get information like parameters
// of the input model etc.
- if (single_option_run_->ParseFlags(&argc, argv) != kTfLiteOk) return;
+ if (TfLiteStatus status = single_option_run_->ParseFlags(&argc, argv);
+ status != kTfLiteOk) {
+ TFLITE_LOG(ERROR)
+ << "Error while parsing the flags for single-option runs: " << status;
+ return status;
+ }
// Now, the remaining are unrecognized flags and we simply print them out.
for (int i = 1; i < argc; ++i) {
TFLITE_LOG(WARN) << "WARNING: unrecognized commandline flag: " << argv[i];
}
- Run();
+ return Run();
}
} // namespace benchmark
} // namespace tflite
diff --git a/tensorflow/lite/tools/benchmark/benchmark_performance_options.h b/tensorflow/lite/tools/benchmark/benchmark_performance_options.h
index 4aeaa8382a7..c13161812f0 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_performance_options.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_performance_options.h
@@ -20,7 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "absl/memory/memory.h"
+#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/tools/benchmark/benchmark_model.h"
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
@@ -85,8 +85,8 @@ class BenchmarkPerformanceOptions {
virtual ~BenchmarkPerformanceOptions() {}
// Just run the benchmark just w/ default parameter values.
- void Run();
- void Run(int argc, char** argv);
+ TfLiteStatus Run();
+ TfLiteStatus Run(int argc, char** argv);
protected:
static BenchmarkParams DefaultParams();
@@ -97,10 +97,10 @@ class BenchmarkPerformanceOptions {
// Unparsable flags will remain in 'argv' in the original order and 'argc'
// will be updated accordingly.
- bool ParseFlags(int* argc, char** argv);
+ TfLiteStatus ParseFlags(int* argc, char** argv);
virtual std::vector<Flag> GetFlags();
- bool ParsePerfOptions();
+ TfLiteStatus ParsePerfOptions();
virtual std::vector<std::string> GetValidPerfOptions() const;
bool HasOption(const std::string& option) const;