summaryrefslogtreecommitdiff
path: root/thermal/virtualtemp_estimator
diff options
context:
space:
mode:
Diffstat (limited to 'thermal/virtualtemp_estimator')
-rw-r--r--thermal/virtualtemp_estimator/virtualtemp_estimator.cpp158
-rw-r--r--thermal/virtualtemp_estimator/virtualtemp_estimator.h58
-rw-r--r--thermal/virtualtemp_estimator/virtualtemp_estimator_data.h76
-rw-r--r--thermal/virtualtemp_estimator/virtualtemp_estimator_test.cpp416
4 files changed, 708 insertions, 0 deletions
diff --git a/thermal/virtualtemp_estimator/virtualtemp_estimator.cpp b/thermal/virtualtemp_estimator/virtualtemp_estimator.cpp
new file mode 100644
index 00000000..2dc2185c
--- /dev/null
+++ b/thermal/virtualtemp_estimator/virtualtemp_estimator.cpp
@@ -0,0 +1,158 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "virtualtemp_estimator.h"
+
+#include <android-base/logging.h>
+#include <dlfcn.h>
+
+#include <vector>
+
+namespace thermal {
+namespace vtestimator {
+
+void VirtualTempEstimator::LoadTFLiteWrapper() {
+ if (!data_) {
+ LOG(ERROR) << "data_ is nullptr during LoadTFLiteWrapper";
+ return;
+ }
+
+ std::unique_lock<std::mutex> lock(data_->tflite_methods.mutex_);
+
+ void *mLibHandle = dlopen("/vendor/lib64/libthermal_tflite_wrapper.so", 0);
+ if (mLibHandle == nullptr) {
+ LOG(ERROR) << "Could not load libthermal_tflite_wrapper library with error: " << dlerror();
+ return;
+ }
+
+ data_->tflite_methods.create =
+ reinterpret_cast<tflitewrapper_create>(dlsym(mLibHandle, "Create"));
+ if (!data_->tflite_methods.create) {
+ LOG(ERROR) << "Could not link and cast tflitewrapper_create with error: " << dlerror();
+ }
+
+ data_->tflite_methods.init = reinterpret_cast<tflitewrapper_init>(dlsym(mLibHandle, "Init"));
+ if (!data_->tflite_methods.init) {
+ LOG(ERROR) << "Could not link and cast tflitewrapper_init with error: " << dlerror();
+ }
+
+ data_->tflite_methods.invoke =
+ reinterpret_cast<tflitewrapper_invoke>(dlsym(mLibHandle, "Invoke"));
+ if (!data_->tflite_methods.invoke) {
+ LOG(ERROR) << "Could not link and cast tflitewrapper_invoke with error: " << dlerror();
+ }
+
+ data_->tflite_methods.destroy =
+ reinterpret_cast<tflitewrapper_destroy>(dlsym(mLibHandle, "Destroy"));
+ if (!data_->tflite_methods.destroy) {
+ LOG(ERROR) << "Could not link and cast tflitewrapper_destroy with error: " << dlerror();
+ }
+}
+
+VirtualTempEstimator::VirtualTempEstimator(size_t num_input_samples) {
+ data_ = std::make_unique<VirtualTempEstimatorTFLiteData>(num_input_samples);
+ LoadTFLiteWrapper();
+}
+
// Logs teardown; data_'s own destructor releases the tflite wrapper handle
// and the input buffer (see VirtualTempEstimatorTFLiteData).
VirtualTempEstimator::~VirtualTempEstimator() {
    LOG(INFO) << "VirtualTempEstimator destructor";
}
+
+VtEstimatorStatus VirtualTempEstimator::Initialize(const char *model_path) {
+ LOG(INFO) << "Initialize VirtualTempEstimator\n";
+
+ if (!data_) {
+ LOG(ERROR) << "data_ is nullptr during Initialize\n";
+ return kVtEstimatorInitFailed;
+ }
+
+ std::unique_lock<std::mutex> lock(data_->tflite_methods.mutex_);
+
+ if (!model_path) {
+ LOG(ERROR) << "Invalid model_path:" << model_path;
+ return kVtEstimatorInvalidArgs;
+ }
+
+ if (!data_->input_buffer || !data_->input_buffer_size) {
+ LOG(ERROR) << "Invalid data_ members " << model_path
+ << " input_buffer: " << data_->input_buffer
+ << " input_buffer_size: " << data_->input_buffer_size;
+ return kVtEstimatorInitFailed;
+ }
+
+ if (!data_->tflite_methods.create || !data_->tflite_methods.init ||
+ !data_->tflite_methods.invoke || !data_->tflite_methods.destroy) {
+ LOG(ERROR) << "Invalid tflite methods";
+ return kVtEstimatorInitFailed;
+ }
+
+ data_->tflite_wrapper = data_->tflite_methods.create(kNumInputTensors, kNumOutputTensors);
+ if (!data_->tflite_wrapper) {
+ LOG(ERROR) << "Failed to create tflite wrapper";
+ return kVtEstimatorInitFailed;
+ }
+
+ int ret = data_->tflite_methods.init(data_->tflite_wrapper, model_path);
+ if (ret) {
+ LOG(ERROR) << "Failed to Init tflite_wrapper for " << model_path << " (ret: )" << ret
+ << ")";
+ return kVtEstimatorInitFailed;
+ }
+
+ data_->is_initialized = true;
+ data_->model_path = model_path;
+
+ LOG(INFO) << "Successfully initialized VirtualTempEstimator for " << model_path;
+ return kVtEstimatorOk;
+}
+
+VtEstimatorStatus VirtualTempEstimator::Estimate(const std::vector<float> &thermistors,
+ float *output) {
+ if (!data_) {
+ LOG(ERROR) << "data_ is nullptr during Estimate\n";
+ return kVtEstimatorInitFailed;
+ }
+
+ std::unique_lock<std::mutex> lock(data_->tflite_methods.mutex_);
+
+ if (!data_->is_initialized) {
+ LOG(ERROR) << "data_ not initialized for " << data_->model_path;
+ return kVtEstimatorInitFailed;
+ }
+
+ if ((thermistors.size() != data_->input_buffer_size) || (!output)) {
+ LOG(ERROR) << "Invalid args for " << data_->model_path
+ << " thermistors.size(): " << thermistors.size()
+ << " input_buffer_size: " << data_->input_buffer_size << " output: " << output;
+ return kVtEstimatorInvalidArgs;
+ }
+
+ // copy input data into input tensors
+ for (size_t i = 0; i < data_->input_buffer_size; ++i) {
+ data_->input_buffer[i] = thermistors[i];
+ }
+
+ int ret = data_->tflite_methods.invoke(data_->tflite_wrapper, data_->input_buffer,
+ data_->input_buffer_size, output, 1);
+ if (ret) {
+ LOG(ERROR) << "Failed to Invoke for " << data_->model_path << " (ret: " << ret << ")";
+ return kVtEstimatorInvokeFailed;
+ }
+
+ return kVtEstimatorOk;
+}
+
+} // namespace vtestimator
+} // namespace thermal
diff --git a/thermal/virtualtemp_estimator/virtualtemp_estimator.h b/thermal/virtualtemp_estimator/virtualtemp_estimator.h
new file mode 100644
index 00000000..0ae37f9e
--- /dev/null
+++ b/thermal/virtualtemp_estimator/virtualtemp_estimator.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <vector>
+
+#include "virtualtemp_estimator_data.h"
+
+namespace thermal {
+namespace vtestimator {
+
// Return codes shared by all VirtualTempEstimator operations.
enum VtEstimatorStatus {
    kVtEstimatorOk = 0,            // operation completed successfully
    kVtEstimatorInvalidArgs = 1,   // caller passed invalid arguments
    kVtEstimatorInitFailed = 2,    // estimator state unusable or setup failed
    kVtEstimatorInvokeFailed = 3,  // model inference failed
    kVtEstimatorUnSupported = 4,   // requested operation not supported
};
+
+// Class to estimate virtual temperature based on a model
+class VirtualTempEstimator {
+ public:
+ // Implicit copy-move headers.
+ VirtualTempEstimator(const VirtualTempEstimator &) = delete;
+ VirtualTempEstimator(VirtualTempEstimator &&) = default;
+ VirtualTempEstimator &operator=(const VirtualTempEstimator &) = delete;
+ VirtualTempEstimator &operator=(VirtualTempEstimator &&) = default;
+
+ VirtualTempEstimator(size_t num_input_samples);
+ ~VirtualTempEstimator();
+
+ // Initializes the model provided by model_path.
+ VtEstimatorStatus Initialize(const char *model_path);
+
+ // Performs the inference on the loaded VT model.
+ // Output of the inference is returned in output argument
+ VtEstimatorStatus Estimate(const std::vector<float> &thermistors, float *output);
+
+ private:
+ void LoadTFLiteWrapper();
+ std::unique_ptr<VirtualTempEstimatorTFLiteData> data_;
+};
+
+} // namespace vtestimator
+} // namespace thermal
diff --git a/thermal/virtualtemp_estimator/virtualtemp_estimator_data.h b/thermal/virtualtemp_estimator/virtualtemp_estimator_data.h
new file mode 100644
index 00000000..935c753b
--- /dev/null
+++ b/thermal/virtualtemp_estimator/virtualtemp_estimator_data.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstddef>
+#include <mutex>
+#include <string>
+
+#pragma once
+
+namespace thermal {
+namespace vtestimator {
+
// Current version only supports single input/output tensors
constexpr int kNumInputTensors = 1;
constexpr int kNumOutputTensors = 1;

// Signatures of the entry points exported by the TFLite wrapper shared
// library, resolved at runtime via dlsym ("Create"/"Init"/"Invoke"/"Destroy").
// The void* handle is the opaque wrapper instance returned by create.
typedef void *(*tflitewrapper_create)(int num_input_tensors, int num_output_tensors);
typedef bool (*tflitewrapper_init)(void *handle, const char *model_path);
typedef bool (*tflitewrapper_invoke)(void *handle, float *input_samples, int num_input_samples,
                                     float *output_samples, int num_output_samples);
typedef void (*tflitewrapper_destroy)(void *handle);
+
// Function pointers resolved from the TFLite wrapper library, plus the mutex
// serializing all estimator operations that touch the wrapper. Pointers are
// nullptr until successfully resolved.
struct TFLiteWrapperMethods {
    tflitewrapper_create create;
    tflitewrapper_init init;
    tflitewrapper_invoke invoke;
    tflitewrapper_destroy destroy;
    // mutable so const holders can still lock it.
    mutable std::mutex mutex_;
};
+
+struct VirtualTempEstimatorTFLiteData {
+ VirtualTempEstimatorTFLiteData(size_t num_input_samples) {
+ input_buffer = new float[num_input_samples];
+ input_buffer_size = num_input_samples;
+ is_initialized = false;
+ tflite_wrapper = nullptr;
+
+ tflite_methods.create = nullptr;
+ tflite_methods.init = nullptr;
+ tflite_methods.invoke = nullptr;
+ tflite_methods.destroy = nullptr;
+ }
+
+ void *tflite_wrapper;
+ float *input_buffer;
+ size_t input_buffer_size;
+ std::string model_path;
+ TFLiteWrapperMethods tflite_methods;
+ bool is_initialized;
+
+ ~VirtualTempEstimatorTFLiteData() {
+ if (tflite_wrapper && tflite_methods.destroy) {
+ tflite_methods.destroy(tflite_wrapper);
+ }
+
+ if (input_buffer) {
+ delete input_buffer;
+ }
+ }
+};
+
+} // namespace vtestimator
+} // namespace thermal
diff --git a/thermal/virtualtemp_estimator/virtualtemp_estimator_test.cpp b/thermal/virtualtemp_estimator/virtualtemp_estimator_test.cpp
new file mode 100644
index 00000000..fde99977
--- /dev/null
+++ b/thermal/virtualtemp_estimator/virtualtemp_estimator_test.cpp
@@ -0,0 +1,416 @@
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/**
+ *@file virtualtemp_estimator_test.cc
+ * Test application to verify virtualtemp estimator
+ *
+ */
+// Test application to run and verify virtualtemp estimator interface unit tests
+
+#include "virtualtemp_estimator.h"
+
+#include <android-base/file.h>
+#include <android-base/logging.h>
+#include <android-base/parsedouble.h>
+#include <android-base/properties.h>
+#include <android-base/strings.h>
+#include <cutils/properties.h>
+#include <cutils/trace.h>
+#include <json/reader.h>
+#include <json/value.h>
+#include <json/writer.h>
+#include <log/log.h>
+#include <malloc.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/time.h>
+
#include <algorithm>
#include <climits>
#include <fstream>
#include <iostream>
#include <string>
#include <string_view>
#include <vector>
+
// Default model/config locations and test-loop tuning constants.
constexpr std::string_view kDefaultModel("/vendor/etc/vt_estimation_model.tflite");
constexpr std::string_view kConfigProperty("vendor.thermal.config");
constexpr std::string_view kConfigDefaultFileName("thermal_info_config.json");
constexpr int kmillion = 1000000;                  // microseconds per second
constexpr int klog_interval_usec = 10 * kmillion;  // progress-log period (10s)
+
+static inline unsigned long get_elapsed_time_usec(struct timeval start, struct timeval end) {
+ unsigned long elapsed_time = (end.tv_sec - start.tv_sec) * kmillion;
+ elapsed_time += (end.tv_usec - start.tv_usec);
+
+ return elapsed_time;
+}
+
// Reads the thermal JSON config at |thermal_config_path| and returns the
// "Combination" list of the sensor named "VIRTUAL-SKIN-MODEL": the ordered
// set of input sensor names the VT model expects.
// Returns an empty vector on read/parse failure or if the combination is
// missing/empty.
static std::vector<std::string> get_input_combination(std::string_view thermal_config_path) {
    std::vector<std::string> result;
    std::string json_doc;
    if (!android::base::ReadFileToString(thermal_config_path.data(), &json_doc)) {
        std::cout << "Failed to read JSON config from " << thermal_config_path.data();
        return result;
    }

    Json::Value root;
    Json::CharReaderBuilder reader_builder;
    std::unique_ptr<Json::CharReader> reader(reader_builder.newCharReader());
    std::string errorMessage;

    // NOTE(review): &*json_doc.begin() is undefined for an empty json_doc —
    // confirm configs are guaranteed non-empty, or prefer data()/data()+size().
    if (!reader->parse(&*json_doc.begin(), &*json_doc.end(), &root, &errorMessage)) {
        std::cout << "Failed to parse JSON config: " << errorMessage;
        return result;
    }

    Json::Value sensors = root["Sensors"];
    std::cout << "Sensors size: " << sensors.size() << std::endl;

    // Scan all sensors for the virtual-skin model entry and collect its
    // input combination in declaration order.
    for (Json::Value::ArrayIndex i = 0; i < sensors.size(); ++i) {
        const std::string &name = sensors[i]["Name"].asString();
        if (name == "VIRTUAL-SKIN-MODEL") {
            Json::Value values = sensors[i]["Combination"];
            if (values.size() == 0) {
                return result;
            }

            std::cout << "Combination : [";
            for (Json::Value::ArrayIndex j = 0; j < values.size(); ++j) {
                result.push_back(values[j].asString());
                std::cout << result.back() << ", ";
            }
            std::cout << "]" << std::endl;
        }
    }

    return result;
}
+
+static int run_random_input_inference(std::string_view model_path,
+ std::string_view thermal_config_path, int min_inference_count,
+ int inference_delay_sec) {
+ float output;
+ unsigned long prev_log_time = 0;
+ thermal::vtestimator::VtEstimatorStatus ret;
+ std::vector<std::string> input_combination = get_input_combination(thermal_config_path.data());
+ int input_size = input_combination.size();
+ thermal::vtestimator::VirtualTempEstimator vt_estimator_(input_size);
+
+ std::cout << "Initialize estimator\n";
+ ret = vt_estimator_.Initialize(model_path.data());
+ if (ret != thermal::vtestimator::kVtEstimatorOk) {
+ std::cout << "Failed to Initialize estimator (ret: " << ret << ")\n";
+ return -1;
+ }
+
+ struct timeval start_loop_time;
+ int inference_count = 0;
+ unsigned long max_inference_time = 0, min_inference_time = ULONG_MAX;
+ unsigned long sum_inference_time = 0;
+ float avg_inference_time = 0;
+ std::vector<unsigned long> inference_times;
+
+ gettimeofday(&start_loop_time, nullptr);
+ do {
+ struct timeval begin, end;
+ std::vector<float> thermistors;
+
+ // preparing random inputs with starting temperature between 20C to 40C
+ int r = 20 + std::rand() % 20;
+ for (int i = 0; i < input_size; ++i) {
+ thermistors.push_back(r + i);
+ }
+
+ gettimeofday(&begin, nullptr);
+ ret = vt_estimator_.Estimate(thermistors, &output);
+ gettimeofday(&end, nullptr);
+ if (ret != thermal::vtestimator::kVtEstimatorOk) {
+ std::cout << "Failed to run estimator (ret: " << ret << ")\n";
+ return -1;
+ }
+
+ unsigned long inference_time_usec = get_elapsed_time_usec(begin, end);
+
+ inference_count++;
+ max_inference_time = std::max(max_inference_time, inference_time_usec);
+ min_inference_time = std::min(min_inference_time, inference_time_usec);
+ sum_inference_time += inference_time_usec;
+ avg_inference_time = sum_inference_time / inference_count;
+ inference_times.push_back(inference_time_usec);
+
+ unsigned long elapsed_time = get_elapsed_time_usec(start_loop_time, end);
+ if (elapsed_time - prev_log_time >= klog_interval_usec) {
+ std::cout << "elapsed_time_sec: " << elapsed_time / kmillion
+ << " inference_count: " << inference_count
+ << " min_inference_time: " << min_inference_time
+ << " max_inference_time: " << max_inference_time
+ << " avg_inference_time: " << avg_inference_time << std::endl;
+ prev_log_time = elapsed_time;
+ }
+
+ if (inference_delay_sec)
+ sleep(inference_delay_sec);
+ } while (inference_count < min_inference_count);
+
+ std::cout << "\n\ntotal inference count: " << inference_count << std::endl;
+ std::cout << "total inference time: " << sum_inference_time << std::endl;
+ std::cout << "avg_inference_time: " << avg_inference_time << std::endl;
+ std::cout << "min_inference_time: " << min_inference_time << std::endl;
+ std::cout << "max_inference_time: " << max_inference_time << std::endl;
+
+ std::sort(inference_times.begin(), inference_times.end());
+ std::cout << "\n\n";
+ std::cout << "p50: " << inference_times[inference_count * 0.5] << std::endl;
+ std::cout << "p90: " << inference_times[inference_count * 0.9] << std::endl;
+
+ return 0;
+}
+
+static int run_single_inference(std::string_view model_path, char *input) {
+ if (!input) {
+ std::cout << "input is nullptr" << std::endl;
+ return -1;
+ }
+
+ std::vector<float> thermistors;
+ char *ip = input;
+ char *saveptr;
+
+ std::cout << "Parsing thermistors from input string: ";
+ ip = strtok_r(ip, " ", &saveptr);
+ while (ip) {
+ float thermistor_value;
+
+ if (sscanf(ip, "%f", &thermistor_value) != 1) {
+ std::cout << "inputs parsing failed";
+ }
+
+ std::cout << thermistor_value << " ";
+ thermistors.push_back(thermistor_value);
+
+ ip = strtok_r(NULL, " ", &saveptr);
+ }
+ std::cout << std::endl;
+ std::cout << "thermistors.size(): " << thermistors.size() << std::endl;
+
+ float output;
+ thermal::vtestimator::VtEstimatorStatus ret;
+ thermal::vtestimator::VirtualTempEstimator vt_estimator_(thermistors.size());
+
+ std::cout << "Initialize estimator\n";
+ ret = vt_estimator_.Initialize(model_path.data());
+ if (ret != thermal::vtestimator::kVtEstimatorOk) {
+ std::cout << "Failed to Initialize estimator (ret: " << ret << ")\n";
+ return -1;
+ }
+
+ std::cout << "run estimator\n";
+ ret = vt_estimator_.Estimate(thermistors, &output);
+ if (ret != thermal::vtestimator::kVtEstimatorOk) {
+ std::cout << "Failed to run estimator (ret: " << ret << ")\n";
+ return -1;
+ }
+
+ std::cout << "output: " << output << std::endl;
+ return 0;
+}
+
+static int run_batch_process(std::string_view model_path, std::string_view thermal_config_path,
+ const char *input_file, const char *output_file) {
+ if (!input_file || !output_file) {
+ std::cout << "input and output files required for batch process\n";
+ return -1;
+ }
+
+ std::cout << "get_input_combination(): ";
+ std::vector<std::string> input_combination = get_input_combination(thermal_config_path.data());
+ if (input_combination.size() == 0) {
+ LOG(ERROR) << "Invalid input_combination";
+ return -1;
+ }
+
+ thermal::vtestimator::VtEstimatorStatus ret;
+ thermal::vtestimator::VirtualTempEstimator vt_estimator_(input_combination.size());
+
+ std::cout << "Initialize estimator\n";
+ ret = vt_estimator_.Initialize(model_path.data());
+ if (ret != thermal::vtestimator::kVtEstimatorOk) {
+ std::cout << "Failed to Initialize estimator (ret: " << ret << ")\n";
+ return -1;
+ }
+
+ std::string json_doc;
+ if (!android::base::ReadFileToString(input_file, &json_doc)) {
+ LOG(ERROR) << "Failed to read JSON config from " << input_file;
+ return -1;
+ }
+ Json::Value root;
+ Json::CharReaderBuilder reader_builder;
+ std::unique_ptr<Json::CharReader> reader(reader_builder.newCharReader());
+ std::string errorMessage;
+
+ if (!reader->parse(&*json_doc.begin(), &*json_doc.end(), &root, &errorMessage)) {
+ LOG(ERROR) << "Failed to parse JSON config: " << errorMessage;
+ return -1;
+ }
+
+ std::cout << "Number of testcases " << root.size() << std::endl;
+
+ for (auto const &testcase_name : root.getMemberNames()) {
+ if (testcase_name == "Metadata") {
+ continue;
+ }
+
+ Json::Value testcase = root[testcase_name];
+ Json::Value model_vt_outputs;
+ int loop_count = testcase[input_combination[0]].size();
+
+ std::cout << "tc: " << testcase_name << " count: " << loop_count << std::endl;
+ for (int i = 0; i < loop_count; ++i) {
+ std::vector<float> model_inputs;
+ float model_output;
+ int num_inputs = input_combination.size();
+
+ for (int j = 0; j < num_inputs; ++j) {
+ std::string input_name = input_combination[j];
+ std::string value_str = testcase[input_name][std::to_string(i)].asString();
+
+ std::cout << "tc[" << testcase_name << "] entry[" << i << "] input[" << input_name
+ << "] value_str[" << value_str << "]\n";
+
+ float value;
+ if (android::base::ParseFloat(value_str, &value) == false) {
+ std::cout << "Failed to parse value_str : " << value_str << " to float\n";
+ }
+
+ model_inputs.push_back(value);
+ }
+
+ ret = vt_estimator_.Estimate(model_inputs, &model_output);
+ if (ret != thermal::vtestimator::kVtEstimatorOk) {
+ std::cout << "Failed to run estimator (ret: " << ret << ")\n";
+ return -1;
+ }
+
+ model_vt_outputs[std::to_string(i)] = std::to_string(model_output);
+ }
+
+ testcase["model_vt"] = model_vt_outputs;
+ root[testcase_name] = testcase;
+ std::cout << "completed testcase_name: " << testcase_name << std::endl;
+ }
+
+ Json::StreamWriterBuilder writer_builder;
+ writer_builder["indentation"] = "";
+ std::unique_ptr<Json::StreamWriter> writer(writer_builder.newStreamWriter());
+ std::ofstream output_stream(output_file, std::ofstream::out);
+ writer->write(root, &output_stream);
+
+ return 0;
+}
+
// Prints the command-line usage summary for this test binary.
void print_usage() {
    // Adjacent string literals concatenate at compile time into the exact
    // message the original built with repeated operator+=.
    const std::string message =
            "usage: \n"
            "-m : input mode ("
            "0: single inference "
            "1: json input file "
            "2: generate random inputs) \n"
            "-p : path to model file \n"
            "-t : path to thermal config file \n"
            "-i : input samples (mode 0), path to input file (mode 1) \n"
            "-o : output file (mode 1) \n"
            "-d : delay between inferences in seconds (mode 2) \n"
            "-c : inference count (mode 2)";

    std::cout << message << std::endl;
}
+
// Entry point: parses command-line options, falls back to default model and
// thermal-config paths when none are given, then dispatches to one of the
// three test modes (0: single inference, 1: batch JSON, 2: random stress).
// Returns the selected mode's status (0 success, -1 failure), or -1 for an
// unknown mode.
int main(int argc, char *argv[]) {
    int c, mode = -1;
    char *input = nullptr, *output = nullptr;
    std::string model_path, thermal_config_path;
    int min_inference_count = -1;  // mode 2: -1 => single pass of the do/while
    int inference_delay_sec = 0;   // mode 2: 0 => no sleep between inferences

    while ((c = getopt(argc, argv, "hm:p:i:c:o:d:t:")) != -1) switch (c) {
            case 'm':
                mode = atoi(optarg);
                std::cout << "mode: " << mode << std::endl;
                break;
            case 'p':
                model_path = optarg;
                std::cout << "model_path: " << model_path << std::endl;
                break;
            case 't':
                thermal_config_path = optarg;
                std::cout << "thermal_config_path: " << thermal_config_path << std::endl;
                break;
            case 'i':
                // Mode 0: literal sample string; mode 1: input JSON file path.
                input = optarg;
                std::cout << "input: " << input << std::endl;
                break;
            case 'o':
                output = optarg;
                std::cout << "output: " << output << std::endl;
                break;
            case 'c':
                min_inference_count = atoi(optarg);
                std::cout << "min_inference_count: " << min_inference_count << std::endl;
                break;
            case 'd':
                inference_delay_sec = atoi(optarg);
                std::cout << "inference_delay_sec : " << inference_delay_sec << std::endl;
                break;
            case 'h':
                print_usage();
                return 0;
            default:
                // Unknown option: abort() so the failure is loud in test runs.
                std::cout << "unsupported option " << c << std::endl;
                abort();
        }

    if (model_path.empty()) {
        model_path = kDefaultModel;
        std::cout << "Using default model_path: " << model_path << std::endl;
    }

    if (thermal_config_path.empty()) {
        // Config file name comes from the vendor.thermal.config property,
        // defaulting to thermal_info_config.json.
        thermal_config_path =
                "/vendor/etc/" +
                android::base::GetProperty(kConfigProperty.data(), kConfigDefaultFileName.data());
        std::cout << "Using default thermal config: " << thermal_config_path << std::endl;
    }

    int ret = -1;
    switch (mode) {
        case 0:
            ret = run_single_inference(model_path, input);
            break;
        case 1:
            ret = run_batch_process(model_path, thermal_config_path, input, output);
            break;
        case 2:
            ret = run_random_input_inference(model_path, thermal_config_path, min_inference_count,
                                             inference_delay_sec);
            break;
        default:
            std::cout << "unsupported mode" << std::endl;
            print_usage();
            break;
    }

    std::cout << "Exiting" << std::endl;
    fflush(stdout);

    return ret;
}