author     Michael Butler <butlermichael@google.com>    2019-10-29 11:51:41 -0700
committer  android-build-merger <android-build-merger@google.com>    2019-10-29 11:51:41 -0700
commit     b34fce3c2a57b2e5148900f9699340ec63f671b6 (patch)
tree       27d70e863b09ccb2471a08e63c835d2777c5f3a6
parent     bd86742ff4244cd721652c4b52ee5a43426cac82 (diff)
parent     9b84d74900feaa7cd32bd8dff90724f1dee8120b (diff)
download   ml-b34fce3c2a57b2e5148900f9699340ec63f671b6.tar.gz

Merge changes I032a5013,Ib1a95b75,I40584542,I026bad69,Icf98281b, ... am: 9c1d0b0442 am: cb060fcfee
am: 9b84d74900

Change-Id: I323d0920ec580027f71f7363da3da9f1e98c1f73
-rw-r--r--  nn/runtime/BurstBuilder.h                        8
-rw-r--r--  nn/runtime/ExecutionBuilder.cpp                208
-rw-r--r--  nn/runtime/ExecutionBuilder.h                   10
-rw-r--r--  nn/runtime/ExecutionPlan.cpp                     9
-rw-r--r--  nn/runtime/Manager.cpp                         216
-rw-r--r--  nn/runtime/Manager.h                            13
-rw-r--r--  nn/runtime/NeuralNetworks.cpp                   23
-rw-r--r--  nn/runtime/TypeManager.cpp                       2
-rw-r--r--  nn/runtime/VersionedInterfaces.cpp             121
-rw-r--r--  nn/runtime/VersionedInterfaces.h                66
-rw-r--r--  nn/runtime/test/TestCompilationCaching.cpp     189
-rw-r--r--  nn/runtime/test/TestIntrospectionControl.cpp     6
-rw-r--r--  nn/runtime/test/TestPartitioning.cpp            18
13 files changed, 477 insertions, 412 deletions
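
The common thread across these files is an API shift inside the runtime: internal methods that used to launch work and report results through sp<ExecutionCallback> or other out-parameters now run synchronously and return their results directly as tuples, vectors, or cached references. A minimal, self-contained sketch of that calling convention follows (placeholder Fake* types, not the runtime's own classes):

    // Illustrative sketch only: placeholder types standing in for the runtime's
    // OutputShape/Timing. Results come back together in one tuple and are
    // unpacked with structured bindings at the call site.
    #include <cstdint>
    #include <tuple>
    #include <vector>

    struct FakeOutputShape { std::vector<uint32_t> dimensions; bool isSufficient; };
    struct FakeTiming { uint64_t timeOnDevice; uint64_t timeInDriver; };

    std::tuple<int, std::vector<FakeOutputShape>, FakeTiming> computeSketch() {
        return {0, {FakeOutputShape{{1, 2, 3}, true}}, FakeTiming{10, 20}};
    }

    int main() {
        auto [status, shapes, timing] = computeSketch();
        return (status == 0 && shapes.size() == 1 && timing.timeOnDevice == 10) ? 0 : 1;
    }
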
diff --git a/nn/runtime/BurstBuilder.h b/nn/runtime/BurstBuilder.h
index bfb9a9a38..6a3ba783e 100644
--- a/nn/runtime/BurstBuilder.h
+++ b/nn/runtime/BurstBuilder.h
@@ -31,10 +31,10 @@ class CompilationBuilder;
* TODO: Could we "hide" the per-step burst controller instance inside
* StepExecutor? Today it's exposed as a "sibling" to StepExecutor:
* ExecutionPlan::next both generates a StepExecutor instance and finds a
- * pointer to a burst controller; and StepExecutor::startCompute is passed a
- * pointer to a burst controller. Instead, could ExecutionPlan::next stash the
- * burst controller in the StepExecutor, so that it doesn't have to be passed
- * to any of the StepExecutor methods?
+ * pointer to a burst controller; and StepExecutor::compute is passed a pointer
+ * to a burst controller. Instead, could ExecutionPlan::next stash the burst
+ * controller in the StepExecutor, so that it doesn't have to be passed to any
+ * of the StepExecutor methods?
*/
class BurstBuilder {
diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp
index 97e847b24..9e73118e9 100644
--- a/nn/runtime/ExecutionBuilder.cpp
+++ b/nn/runtime/ExecutionBuilder.cpp
@@ -303,12 +303,9 @@ int ExecutionBuilder::getOutputOperandRank(uint32_t index, uint32_t* rank) {
// For Q this is irrelevant: We only support timing in conjunction
// with an explicit device list; and we do not support CPU fallback
// with an explicit device list. See CompilationBuilder::mExplicitDeviceList.
-static int cpuFallbackFull(ExecutionBuilder* executionBuilder,
- sp<ExecutionCallback>* fallbackCallback) {
+static std::tuple<int, std::vector<OutputShape>, Timing> cpuFallbackFull(
+ ExecutionBuilder* executionBuilder) {
CHECK(executionBuilder != nullptr);
- CHECK(fallbackCallback != nullptr);
- *fallbackCallback = nullptr;
-
NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "cpuFallbackFull");
VLOG(EXECUTION) << "cpuFallbackFull";
@@ -318,58 +315,45 @@ static int cpuFallbackFull(ExecutionBuilder* executionBuilder,
executor.mapInputsAndOutputsTrivially();
// Attempt fallback execution.
- NN_RETURN_IF_ERROR(executor.startComputeOnCpuFallback(fallbackCallback));
- CHECK(*fallbackCallback != nullptr);
- (*fallbackCallback)->wait();
- return ANEURALNETWORKS_NO_ERROR;
+ return executor.computeOnCpuFallback();
}
// Attempt synchronous execution on CPU.
-// fallbackExecutor is non-null i.f.f. ANEURALNETWORKS_NO_ERROR is returned.
-// fallbackCallback is non-null i.f.f. ANEURALNETWORKS_NO_ERROR is returned.
// TODO: How should we handle timing in this case?
// For Q this is irrelevant: We only support timing in conjunction
// with an explicit device list; and we do not support CPU fallback
// with an explicit device list. See CompilationBuilder::mExplicitDeviceList.
-static int cpuFallbackPartial(const ExecutionPlan* plan,
- std::shared_ptr<ExecutionPlan::Controller> controller,
- std::shared_ptr<StepExecutor>* fallbackExecutor,
- sp<ExecutionCallback>* fallbackCallback) {
- CHECK(plan != nullptr);
- CHECK(fallbackExecutor != nullptr);
- *fallbackExecutor = nullptr;
- CHECK(fallbackCallback != nullptr);
- *fallbackCallback = nullptr;
-
+static std::tuple<int, std::vector<OutputShape>, Timing, std::shared_ptr<StepExecutor>>
+cpuFallbackPartial(const ExecutionPlan& plan,
+ std::shared_ptr<ExecutionPlan::Controller> controller) {
NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "cpuFallbackPartial");
VLOG(EXECUTION) << "cpuFallbackPartial";
// Get fallback executor.
std::shared_ptr<StepExecutor> executor;
- NN_RETURN_IF_ERROR(plan->fallback(controller, &executor));
+ int n1 = plan.fallback(controller, &executor);
+ if (n1 != ANEURALNETWORKS_NO_ERROR) {
+ return {n1, {}, kNoTiming, nullptr};
+ }
CHECK(executor != nullptr);
// Attempt fallback execution.
- NN_RETURN_IF_ERROR(executor->startComputeOnCpuFallback(fallbackCallback));
- CHECK(*fallbackCallback != nullptr);
- (*fallbackCallback)->wait();
- *fallbackExecutor = executor;
- return ANEURALNETWORKS_NO_ERROR;
+ auto [n2, outputShapes, timing] = executor->computeOnCpuFallback();
+ return {n2, std::move(outputShapes), timing, executor};
}
static void asyncStartComputePartitioned(ExecutionBuilder* executionBuilder,
- const ExecutionPlan* plan,
+ const ExecutionPlan& plan,
std::shared_ptr<ExecutionPlan::Controller> controller,
bool allowFallback,
const sp<ExecutionCallback>& executionCallback) {
CHECK(executionBuilder != nullptr);
- CHECK(plan != nullptr);
VLOG(EXECUTION) << "ExecutionBuilder::compute (from plan, iteratively)";
- std::vector<OutputShape> outputShapes;
+
+ std::vector<OutputShape> outputShapes = executionBuilder->getInitialOutputShapes();
Timing timing = kNoTiming;
// Disallow fallback when the ExecutionPlan is simple on CPU.
- allowFallback &= !plan->isSimpleCpu();
- executionBuilder->initializeOutputShapes(&outputShapes);
+ allowFallback &= !plan.isSimpleCpu();
while (true) {
VLOG(EXECUTION) << "looking for next StepExecutor";
@@ -377,7 +361,7 @@ static void asyncStartComputePartitioned(ExecutionBuilder* executionBuilder,
// Get the current step of the execution.
std::shared_ptr<StepExecutor> executor;
std::shared_ptr<ExecutionBurstController> burstController;
- int n = plan->next(controller, &executor, &burstController);
+ int n = plan.next(controller, &executor, &burstController);
if (n != ANEURALNETWORKS_NO_ERROR) {
if (allowFallback) break;
executionCallback->notify(convertResultCodeToErrorStatus(n), {}, kNoTiming);
@@ -393,99 +377,75 @@ static void asyncStartComputePartitioned(ExecutionBuilder* executionBuilder,
const bool executorIsCpu = executor->isCpu();
// Attempt to execute a single step of the execution.
- sp<ExecutionCallback> stepCallback;
- n = executor->startCompute(&stepCallback, burstController);
+ auto [stepN, stepOutputShapes, stepTiming] = executor->compute(burstController);
- // Immediately end execution if there was an error and fallback is not
- // allowed.
- if (n != ANEURALNETWORKS_NO_ERROR && !allowFallback) {
- executionCallback->notify(convertResultCodeToErrorStatus(n), {}, kNoTiming);
- return;
+ // Update global outputs.
+ if (!executor->updateOutputShapes(stepOutputShapes, &outputShapes)) {
+ stepN = ANEURALNETWORKS_OP_FAILED;
}
- // If execution successfully launched, process the execution.
- if (n == ANEURALNETWORKS_NO_ERROR) {
- stepCallback->wait();
- ErrorStatus status = stepCallback->getStatus();
- const auto& stepOutputShapes = stepCallback->getOutputShapes();
-
- // Update global outputs.
- if (!executor->updateOutputShapes(stepOutputShapes, &outputShapes)) {
- status = ErrorStatus::GENERAL_FAILURE;
- }
-
- // If execution was successful, continue to next step.
- if (status == ErrorStatus::NONE) {
- // We only support collection of timing information in the case of a
- // single step, so it's safe to just keep track of the last step's
- // timing information.
- timing = stepCallback->getTiming();
- continue;
- }
-
- // OUTPUT_INSUFFICIENT_SIZE is not recoverable, so end execution.
- if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
- executionCallback->notify(status, outputShapes, kNoTiming);
- return;
- }
+ // If execution was successful, continue to next step.
+ if (stepN == ANEURALNETWORKS_NO_ERROR) {
+ // We only support collection of timing information in the case of a
+ // single step, so it's safe to just keep track of the last step's
+ // timing information.
+ timing = stepTiming;
+ continue;
+ }
- // If fallback is not allowed and there was an error, end execution.
- if (!allowFallback) {
- executionCallback->notify(status, {}, kNoTiming);
- return;
- }
+ // OUTPUT_INSUFFICIENT_SIZE is not recoverable, so end execution.
+ if (stepN == ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE) {
+ const ErrorStatus stepStatus = convertResultCodeToErrorStatus(stepN);
+ executionCallback->notify(stepStatus, outputShapes, kNoTiming);
+ return;
+ }
- // Propagate error to fallback path.
- n = convertErrorStatusToResultCode(status);
+ // If fallback is not allowed and there was an error, end execution.
+ if (!allowFallback) {
+ const ErrorStatus stepStatus = convertResultCodeToErrorStatus(stepN);
+ executionCallback->notify(stepStatus, {}, kNoTiming);
+ return;
}
// If CPU execution was already attempted, either:
// (1) perform a full fallback if the plan is not simple, or
// (2) return from the function with an error
if (executorIsCpu) {
- if (!plan->isSimple()) break;
- executionCallback->notify(convertResultCodeToErrorStatus(n), {}, kNoTiming);
+ if (!plan.isSimple()) break;
+ executionCallback->notify(convertResultCodeToErrorStatus(stepN), {}, kNoTiming);
return;
}
// If the code reaches this point, attempt a partial fallback to CPU.
CHECK(allowFallback);
- std::shared_ptr<StepExecutor> fallbackExecutor;
- sp<ExecutionCallback> fallbackCallback;
- n = cpuFallbackPartial(plan, controller, &fallbackExecutor, &fallbackCallback);
-
- // Immediately fallback to full CPU execution if there was an error with
- // the partial CPU fallback.
- if (n != ANEURALNETWORKS_NO_ERROR) {
- break;
- }
-
- // Get fallback execution results.
- ErrorStatus fallbackStatus = fallbackCallback->getStatus();
- const auto& fallbackOutputShapes = fallbackCallback->getOutputShapes();
+ auto [fallbackN, fallbackOutputShapes, fallbackTiming, fallbackExecutor] =
+ cpuFallbackPartial(plan, controller);
// Update global outputs.
- if (!fallbackExecutor->updateOutputShapes(fallbackOutputShapes, &outputShapes)) {
- fallbackStatus = ErrorStatus::GENERAL_FAILURE;
+ if (fallbackExecutor != nullptr &&
+ !fallbackExecutor->updateOutputShapes(fallbackOutputShapes, &outputShapes)) {
+ fallbackN = ANEURALNETWORKS_OP_FAILED;
}
// If execution was successful, continue to next step.
- if (fallbackStatus == ErrorStatus::NONE) {
+ if (fallbackN == ANEURALNETWORKS_NO_ERROR) {
// We only support collection of timing information in the case of a
// single step, so it's safe to just keep track of the last step's
// timing information.
- timing = fallbackCallback->getTiming();
+ timing = fallbackTiming;
continue;
}
// OUTPUT_INSUFFICIENT_SIZE is not recoverable, so end execution.
- if (fallbackStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
+ if (fallbackN == ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE) {
+ const ErrorStatus fallbackStatus = convertResultCodeToErrorStatus(fallbackN);
executionCallback->notify(fallbackStatus, outputShapes, kNoTiming);
return;
}
// Do not fallback twice if the ExecutionPlan is simple.
- if (plan->isSimple()) {
+ if (plan.isSimple()) {
+ const ErrorStatus fallbackStatus = convertResultCodeToErrorStatus(fallbackN);
executionCallback->notify(fallbackStatus, {}, kNoTiming);
return;
}
@@ -498,14 +458,9 @@ static void asyncStartComputePartitioned(ExecutionBuilder* executionBuilder,
// If the code has reached this point, a potentially recoverable error
// occurred during the step executions. Instead, do a full execution
// fallback on the CPU.
- sp<ExecutionCallback> fallbackCallback;
- int n = cpuFallbackFull(executionBuilder, &fallbackCallback);
- if (n != ANEURALNETWORKS_NO_ERROR) {
- executionCallback->notify(convertResultCodeToErrorStatus(n), {}, kNoTiming);
- return;
- }
- executionCallback->notify(fallbackCallback->getStatus(), fallbackCallback->getOutputShapes(),
- fallbackCallback->getTiming());
+ auto [fullN, fullOutputShapes, fullTiming] = cpuFallbackFull(executionBuilder);
+ const ErrorStatus fullStatus = convertResultCodeToErrorStatus(fullN);
+ executionCallback->notify(fullStatus, fullOutputShapes, fullTiming);
}
int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback,
@@ -558,7 +513,7 @@ int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback,
VLOG(EXECUTION) << "ExecutionBuilder::compute (synchronous API)";
sp<ExecutionCallback> localSynchronizationCallback = new ExecutionCallback();
localSynchronizationCallback->setOnFinish(wrappedFinish);
- asyncStartComputePartitioned(this, mPlan, controller, allowFallback,
+ asyncStartComputePartitioned(this, *mPlan, controller, allowFallback,
localSynchronizationCallback);
localSynchronizationCallback->wait();
if (mMeasureTiming) {
@@ -579,24 +534,28 @@ int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback,
executionCallback->setOnFinish(wrappedFinish);
if (DeviceManager::get()->syncExecRuntime()) {
VLOG(EXECUTION) << "ExecutionBuilder::compute (asynchronous API, non-threaded)";
- asyncStartComputePartitioned(this, mPlan, controller, allowFallback, executionCallback);
+ asyncStartComputePartitioned(this, *mPlan, controller, allowFallback,
+ executionCallback);
} else {
VLOG(EXECUTION) << "ExecutionBuilder::compute (asynchronous API)";
- std::thread thread(asyncStartComputePartitioned, this, mPlan, controller, allowFallback,
- executionCallback);
- executionCallback->bindThread(std::move(thread));
+ std::thread asyncExecution([this, controller, allowFallback, executionCallback] {
+ asyncStartComputePartitioned(this, *mPlan, controller, allowFallback,
+ executionCallback);
+ });
+ executionCallback->bindThread(std::move(asyncExecution));
}
*synchronizationCallback = executionCallback;
return ANEURALNETWORKS_NO_ERROR;
}
}
-void ExecutionBuilder::initializeOutputShapes(std::vector<OutputShape>* outputShapes) const {
- outputShapes->resize(mOutputs.size());
- for (uint32_t i = 0; i < mOutputs.size(); i++) {
- (*outputShapes)[i].dimensions = mOutputs[i].dimensions;
- (*outputShapes)[i].isSufficient = true;
- }
+std::vector<OutputShape> ExecutionBuilder::getInitialOutputShapes() const {
+ std::vector<OutputShape> outputShapes(mOutputs.size());
+ std::transform(mOutputs.begin(), mOutputs.end(), outputShapes.begin(),
+ [](const auto& x) -> OutputShape {
+ return {.dimensions = x.dimensions, .isSufficient = true};
+ });
+ return outputShapes;
}
// Check if the dimensions "to" is updatable by dimensions "from", where "from" must
@@ -741,11 +700,9 @@ bool StepExecutor::isCpu() const {
return mDevice == DeviceManager::getCpuDevice();
}
-int StepExecutor::startCompute(sp<ExecutionCallback>* synchronizationCallback,
- const std::shared_ptr<ExecutionBurstController>& burstController) {
+std::tuple<int, std::vector<OutputShape>, Timing> StepExecutor::compute(
+ const std::shared_ptr<ExecutionBurstController>& burstController) {
CHECK(mPreparedModel != nullptr);
- CHECK(synchronizationCallback != nullptr);
- *synchronizationCallback = nullptr;
if (VLOG_IS_ON(EXECUTION)) {
logArguments("input", mInputs);
@@ -757,19 +714,12 @@ int StepExecutor::startCompute(sp<ExecutionCallback>* synchronizationCallback,
mPreparedModel->execute(mInputs, mOutputs, mMemories, burstController, measure);
mExecutionBuilder->reportTiming(timing);
- if (n != ANEURALNETWORKS_NO_ERROR && n != ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE) {
- return n;
- }
-
- const ErrorStatus status = convertResultCodeToErrorStatus(n);
- *synchronizationCallback = new ExecutionCallback();
- (*synchronizationCallback)->notify_1_2(status, outputShapes, timing);
- return ANEURALNETWORKS_NO_ERROR;
+ return {n, std::move(outputShapes), timing};
}
// For cpuFallback{Partial,Full}, recompile the model on CPU and then start compute.
-int StepExecutor::startComputeOnCpuFallback(sp<ExecutionCallback>* synchronizationCallback) {
- NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "StepExecutor::startComputeOnCpuFallback");
+std::tuple<int, std::vector<OutputShape>, Timing> StepExecutor::computeOnCpuFallback() {
+ NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "StepExecutor::computeOnCpuFallback");
VLOG(EXECUTION) << "Re-compile the model on CPU";
mDevice = DeviceManager::getCpuDevice();
mPreparedModel = nullptr;
@@ -780,8 +730,10 @@ int StepExecutor::startComputeOnCpuFallback(sp<ExecutionCallback>* synchronizati
static_cast<ExecutionPreference>(ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER);
const auto [n, preparedModel] = mDevice->prepareModel(makeModel, preference, {}, {});
mPreparedModel = preparedModel;
- NN_RETURN_IF_ERROR(n);
- return startCompute(synchronizationCallback, /*burstController=*/nullptr);
+ if (n != ANEURALNETWORKS_NO_ERROR) {
+ return {n, {}, kNoTiming};
+ }
+ return compute(/*burstController=*/nullptr);
}
} // namespace nn
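
With the change above, asyncStartComputePartitioned drives each StepExecutor::compute() call directly, merges the step's output shapes into the global list, keeps only the last step's timing, and reserves the callback solely for the final notify. A reduced sketch of that loop, assuming a hypothetical computeStep helper and plain standard-library types rather than the runtime's own:

    // Reduced sketch of the per-step flow; error handling stands in for the
    // partial/full CPU fallback logic in asyncStartComputePartitioned.
    #include <tuple>
    #include <vector>

    using Shape = std::vector<unsigned>;
    constexpr int kOk = 0;

    // Stand-in for StepExecutor::compute(): status, per-step output shapes, timing.
    std::tuple<int, std::vector<Shape>, long> computeStep(unsigned stepIndex) {
        return {kOk, {Shape{1u, stepIndex}}, 42L};
    }

    int main() {
        std::vector<Shape> outputShapes;  // global shapes, grown after every successful step
        long lastTiming = 0;
        for (unsigned i = 0; i < 3; ++i) {
            auto [n, stepShapes, timing] = computeStep(i);
            if (n != kOk) return n;  // the real loop would attempt a CPU fallback here
            outputShapes.insert(outputShapes.end(), stepShapes.begin(), stepShapes.end());
            lastTiming = timing;     // only the last step's timing is reported
        }
        return (outputShapes.size() == 3 && lastTiming == 42) ? 0 : 1;
    }
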
diff --git a/nn/runtime/ExecutionBuilder.h b/nn/runtime/ExecutionBuilder.h
index f07335a01..3d8ab3e6c 100644
--- a/nn/runtime/ExecutionBuilder.h
+++ b/nn/runtime/ExecutionBuilder.h
@@ -69,7 +69,7 @@ class ExecutionBuilder {
int burstCompute(BurstBuilder* burst) { return compute(nullptr, burst); }
// Initialize output dimensional information from ModelArgumentInfo.
- void initializeOutputShapes(std::vector<hal::OutputShape>* outputShapes) const;
+ std::vector<hal::OutputShape> getInitialOutputShapes() const;
int getOutputOperandDimensions(uint32_t index, uint32_t* dimensions);
int getOutputOperandRank(uint32_t index, uint32_t* rank);
@@ -160,7 +160,7 @@ class StepExecutor {
// is executing the entire model from the ExecutionBuilder).
void mapInputsAndOutputsTrivially();
- // Update output shapes returned from ExecutionCallback to ExecutionBuilder.
+ // Update output shapes with shapes returned from execution.
bool updateOutputShapes(const std::vector<hal::OutputShape>& from,
std::vector<hal::OutputShape>* to);
@@ -189,12 +189,12 @@ class StepExecutor {
}
// Executes using the (driver, preparedModel) specified at construction time.
- int startCompute(sp<ExecutionCallback>* synchronizationCallback,
- const std::shared_ptr<ExecutionBurstController>& burstController = nullptr);
+ std::tuple<int, std::vector<hal::OutputShape>, hal::Timing> compute(
+ const std::shared_ptr<ExecutionBurstController>& burstController = nullptr);
// Re-compiles and executes using the CPU, regardless of the (driver,
// preparedModel) specified at construction time.
- int startComputeOnCpuFallback(sp<ExecutionCallback>* synchronizationCallback);
+ std::tuple<int, std::vector<hal::OutputShape>, hal::Timing> computeOnCpuFallback();
bool isCpu() const;
diff --git a/nn/runtime/ExecutionPlan.cpp b/nn/runtime/ExecutionPlan.cpp
index 901305216..ed38aec2f 100644
--- a/nn/runtime/ExecutionPlan.cpp
+++ b/nn/runtime/ExecutionPlan.cpp
@@ -71,8 +71,9 @@ int compile(const Device& device, const ModelBuilder& model, int executionPrefer
*preparedModel = nullptr;
std::optional<CacheToken> cacheToken;
- if (device.isCachingSupported() && token->ok() && token->updateFromString(device.getName()) &&
- token->updateFromString(device.getVersionString()) &&
+ if (device.isCachingSupported() && token->ok() &&
+ token->updateFromString(device.getName().c_str()) &&
+ token->updateFromString(device.getVersionString().c_str()) &&
token->update(&executionPreference, sizeof(executionPreference)) && token->finish()) {
cacheToken.emplace(token->getCacheToken());
}
@@ -996,13 +997,13 @@ class CanDo {
CanDo() {}
void initialize(const MetaModel& metaModel, std::shared_ptr<Device> device) {
- device->getSupportedOperations(metaModel, &mSupportsOperationByIndex);
+ mSupportsOperationByIndex = device->getSupportedOperations(metaModel);
}
bool check(size_t operationIndex) const { return mSupportsOperationByIndex[operationIndex]; }
private:
- hidl_vec<bool> mSupportsOperationByIndex;
+ std::vector<bool> mSupportsOperationByIndex;
};
} // anonymous namespace
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index 34378b3fd..1be1bf68b 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -52,28 +52,38 @@ const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX
// A Device with actual underlying driver
class DriverDevice : public Device {
public:
- DriverDevice(std::string name, const sp<V1_0::IDevice>& device);
-
- // Returns true if succesfully initialized.
- bool initialize();
-
- const char* getName() const override { return mName.c_str(); }
- const char* getVersionString() const override { return mVersionString.c_str(); }
- int64_t getFeatureLevel() const override { return mInterface->getFeatureLevel(); }
- int32_t getType() const override { return mInterface->getType(); }
- hidl_vec<Extension> getSupportedExtensions() const override;
- void getSupportedOperations(const MetaModel& metaModel,
- hidl_vec<bool>* supportedOperations) const override;
- PerformanceInfo getPerformance(OperandType type) const override;
+ // Create a DriverDevice from a name and an IDevice.
+ // Returns nullptr on failure.
+ static std::shared_ptr<DriverDevice> create(std::string name, sp<V1_0::IDevice> device);
+
+ // Prefer using DriverDevice::create
+ DriverDevice(std::shared_ptr<VersionedIDevice> device);
+
+ const std::string& getName() const override { return kInterface->getName(); }
+ const std::string& getVersionString() const override { return kInterface->getVersionString(); }
+ int64_t getFeatureLevel() const override { return kInterface->getFeatureLevel(); }
+ int32_t getType() const override { return kInterface->getType(); }
+ const std::vector<Extension>& getSupportedExtensions() const override {
+ return kInterface->getSupportedExtensions();
+ }
+ std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const override;
+ PerformanceInfo getPerformance(OperandType type) const override {
+ const auto& capabilities = kInterface->getCapabilities();
+ return lookup(capabilities.operandPerformance, type);
+ }
PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const override {
- return mCapabilities.relaxedFloat32toFloat16PerformanceScalar;
+ const auto& capabilities = kInterface->getCapabilities();
+ return capabilities.relaxedFloat32toFloat16PerformanceScalar;
}
PerformanceInfo getRelaxedFloat32toFloat16PerformanceTensor() const override {
- return mCapabilities.relaxedFloat32toFloat16PerformanceTensor;
+ const auto& capabilities = kInterface->getCapabilities();
+ return capabilities.relaxedFloat32toFloat16PerformanceTensor;
}
bool isCachingSupported() const override {
// Caching is supported if either of numModelCache or numDataCache is greater than 0.
- return mNumCacheFiles.first > 0 || mNumCacheFiles.second > 0;
+ const auto [numModelCacheFiles, numDataCacheFiles] =
+ kInterface->getNumberOfCacheFilesNeeded();
+ return numModelCacheFiles > 0 || numDataCacheFiles > 0;
}
std::pair<int, std::shared_ptr<PreparedModel>> prepareModel(
@@ -88,12 +98,7 @@ class DriverDevice : public Device {
std::pair<int, std::shared_ptr<PreparedModel>> prepareModelFromCacheInternal(
const std::string& cacheDir, const CacheToken& token) const;
- std::string mName;
- std::string mVersionString;
- const std::shared_ptr<VersionedIDevice> mInterface;
- Capabilities mCapabilities;
- hidl_vec<Extension> mSupportedExtensions;
- std::pair<uint32_t, uint32_t> mNumCacheFiles;
+ const std::shared_ptr<VersionedIDevice> kInterface;
#ifdef NN_DEBUGGABLE
// For debugging: behavior of IDevice::getSupportedOperations for SampleDriver.
@@ -124,104 +129,57 @@ class DriverPreparedModel : public PreparedModel {
const std::shared_ptr<VersionedIPreparedModel> mPreparedModel;
};
-DriverDevice::DriverDevice(std::string name, const sp<V1_0::IDevice>& device)
- : mName(std::move(name)), mInterface(VersionedIDevice::create(mName, device)) {}
-
-// TODO: handle errors from initialize correctly
-bool DriverDevice::initialize() {
+DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device)
+ : kInterface(std::move(device)) {
#ifdef NN_DEBUGGABLE
static const char samplePrefix[] = "sample";
-
- mSupported = (mName.substr(0, sizeof(samplePrefix) - 1) == samplePrefix)
- ? getProp("debug.nn.sample.supported")
- : 0;
-#endif // NN_DEBUGGABLE
-
- ErrorStatus status = ErrorStatus::GENERAL_FAILURE;
-
- if (mInterface == nullptr) {
- LOG(ERROR) << "DriverDevice contains invalid interface object.";
- return false;
- }
-
- std::tie(status, mCapabilities) = mInterface->getCapabilities();
- if (status != ErrorStatus::NONE) {
- LOG(ERROR) << "IDevice::getCapabilities returned the error " << toString(status);
- return false;
- }
- VLOG(MANAGER) << "Capab " << toString(mCapabilities);
-
- std::tie(status, mVersionString) = mInterface->getVersionString();
- // TODO(miaowang): add a validation test case for in case of error.
- if (status != ErrorStatus::NONE) {
- LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(status);
- return false;
- }
-
- std::tie(status, mSupportedExtensions) = mInterface->getSupportedExtensions();
- if (status != ErrorStatus::NONE) {
- LOG(ERROR) << "IDevice::getSupportedExtensions returned the error " << toString(status);
- return false;
+ if (getName().substr(0, sizeof(samplePrefix) - 1) == samplePrefix) {
+ mSupported = getProp("debug.nn.sample.supported");
}
+#endif // NN_DEBUGGABLE
+}
- std::tie(status, mNumCacheFiles.first, mNumCacheFiles.second) =
- mInterface->getNumberOfCacheFilesNeeded();
- if (status != ErrorStatus::NONE) {
- LOG(WARNING) << "IDevice::getNumberOfCacheFilesNeeded returned the error "
- << toString(status);
- mNumCacheFiles = {0, 0};
- }
- if (mNumCacheFiles.first > static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) ||
- mNumCacheFiles.second > static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES)) {
- LOG(WARNING)
- << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files "
- "numModelCache = "
- << mNumCacheFiles.first << ", numDataCache = " << mNumCacheFiles.second;
- mNumCacheFiles = {0, 0};
+std::shared_ptr<DriverDevice> DriverDevice::create(std::string name, sp<V1_0::IDevice> device) {
+ CHECK(device != nullptr);
+ std::shared_ptr<VersionedIDevice> versionedDevice =
+ VersionedIDevice::create(name, std::move(device));
+ if (versionedDevice == nullptr) {
+ LOG(ERROR) << "DriverDevice::create failed to create VersionedIDevice object for service "
+ << name;
+ return nullptr;
}
- return true;
-}
-hidl_vec<Extension> DriverDevice::getSupportedExtensions() const {
- return mSupportedExtensions;
+ return std::make_shared<DriverDevice>(std::move(versionedDevice));
}
-void DriverDevice::getSupportedOperations(const MetaModel& metaModel,
- hidl_vec<bool>* outSupportedOperations) const {
+std::vector<bool> DriverDevice::getSupportedOperations(const MetaModel& metaModel) const {
// Query the driver for what it can do.
ErrorStatus status = ErrorStatus::GENERAL_FAILURE;
- hidl_vec<bool> supportedOperations;
- std::tie(status, supportedOperations) = mInterface->getSupportedOperations(metaModel);
+ std::vector<bool> supportedOperations;
+ std::tie(status, supportedOperations) = kInterface->getSupportedOperations(metaModel);
const Model& hidlModel = metaModel.getModel();
if (status != ErrorStatus::NONE) {
LOG(ERROR) << "IDevice::getSupportedOperations returned the error " << toString(status);
// Set the supported operation vectors to all false, so we won't use this driver.
- outSupportedOperations->resize(hidlModel.operations.size());
- std::fill(outSupportedOperations->begin(), outSupportedOperations->end(), false);
- return;
+ return std::vector<bool>(hidlModel.operations.size(), false);
}
if (supportedOperations.size() != hidlModel.operations.size()) {
LOG(ERROR) << "IDevice::getSupportedOperations returned a vector of length "
<< supportedOperations.size() << " when expecting "
<< hidlModel.operations.size();
// Set the supported operation vectors to all false, so we won't use this driver.
- outSupportedOperations->resize(hidlModel.operations.size());
- std::fill(outSupportedOperations->begin(), outSupportedOperations->end(), false);
- return;
+ return std::vector<bool>(hidlModel.operations.size(), false);
}
- *outSupportedOperations = std::move(supportedOperations);
-
#ifdef NN_DEBUGGABLE
if (mSupported != 1) {
- return;
+ return supportedOperations;
}
- const uint32_t baseAccumulator = std::hash<std::string>{}(mName);
- for (size_t operationIndex = 0; operationIndex < outSupportedOperations->size();
- operationIndex++) {
- if (!(*outSupportedOperations)[operationIndex]) {
+ const uint32_t baseAccumulator = std::hash<std::string>{}(getName());
+ for (size_t operationIndex = 0; operationIndex < supportedOperations.size(); operationIndex++) {
+ if (!supportedOperations[operationIndex]) {
continue;
}
@@ -245,14 +203,12 @@ void DriverDevice::getSupportedOperations(const MetaModel& metaModel,
accumulateOperands(operation.inputs);
accumulateOperands(operation.outputs);
if (accumulator & 1) {
- (*outSupportedOperations)[operationIndex] = false;
+ supportedOperations[operationIndex] = false;
}
}
#endif // NN_DEBUGGABLE
-}
-PerformanceInfo DriverDevice::getPerformance(OperandType type) const {
- return lookup(mCapabilities.operandPerformance, type);
+ return supportedOperations;
}
// Opens cache file by filename and sets the handle to the opened fd. Returns false on fail. The
@@ -322,7 +278,7 @@ static bool getCacheHandles(const std::string& cacheDir, const CacheToken& token
static std::pair<int, std::shared_ptr<PreparedModel>> prepareModelCheck(
ErrorStatus status, const std::shared_ptr<VersionedIPreparedModel>& preparedModel,
- const char* prepareName, const char* serviceName) {
+ const char* prepareName, const std::string& serviceName) {
if (status != ErrorStatus::NONE) {
LOG(ERROR) << prepareName << " on " << serviceName << " failed: "
<< "prepareReturnStatus=" << toString(status);
@@ -344,7 +300,7 @@ std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelInterna
hidl_vec<hidl_handle> modelCache, dataCache;
if (!maybeToken.has_value() ||
- !getCacheHandles(cacheDir, *maybeToken, mNumCacheFiles,
+ !getCacheHandles(cacheDir, *maybeToken, kInterface->getNumberOfCacheFilesNeeded(),
/*createIfNotExist=*/true, &modelCache, &dataCache)) {
modelCache.resize(0);
dataCache.resize(0);
@@ -353,7 +309,7 @@ std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelInterna
static const CacheToken kNullToken{};
const CacheToken token = maybeToken.value_or(kNullToken);
const auto [status, preparedModel] =
- mInterface->prepareModel(model, preference, modelCache, dataCache, token);
+ kInterface->prepareModel(model, preference, modelCache, dataCache, token);
return prepareModelCheck(status, preparedModel, "prepareModel", getName());
}
@@ -365,13 +321,13 @@ std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelFromCac
VLOG(COMPILATION) << "prepareModelFromCache";
hidl_vec<hidl_handle> modelCache, dataCache;
- if (!getCacheHandles(cacheDir, token, mNumCacheFiles,
+ if (!getCacheHandles(cacheDir, token, kInterface->getNumberOfCacheFilesNeeded(),
/*createIfNotExist=*/false, &modelCache, &dataCache)) {
return {ANEURALNETWORKS_OP_FAILED, nullptr};
}
const auto [status, preparedModel] =
- mInterface->prepareModelFromCache(modelCache, dataCache, token);
+ kInterface->prepareModelFromCache(modelCache, dataCache, token);
return prepareModelCheck(status, preparedModel, "prepareModelFromCache", getName());
}
@@ -571,13 +527,14 @@ class CpuDevice : public Device {
return instance;
}
- const char* getName() const override { return kName.c_str(); }
- const char* getVersionString() const override { return kVersionString.c_str(); }
+ const std::string& getName() const override { return kName; }
+ const std::string& getVersionString() const override { return kVersionString; }
int64_t getFeatureLevel() const override { return kFeatureLevel; }
int32_t getType() const override { return ANEURALNETWORKS_DEVICE_CPU; }
- hidl_vec<Extension> getSupportedExtensions() const override { return {/* No extensions. */}; }
- void getSupportedOperations(const MetaModel& metaModel,
- hidl_vec<bool>* supportedOperations) const override;
+ const std::vector<Extension>& getSupportedExtensions() const override {
+ return kSupportedExtensions;
+ }
+ std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const override;
PerformanceInfo getPerformance(OperandType) const override { return kPerformance; }
PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const override {
return kPerformance;
@@ -600,6 +557,7 @@ class CpuDevice : public Device {
// Since the performance is a ratio compared to the CPU performance,
// by definition the performance of the CPU is 1.0.
const PerformanceInfo kPerformance = {.execTime = 1.0f, .powerUsage = 1.0f};
+ const std::vector<Extension> kSupportedExtensions{/* No extensions. */};
};
// A special abstracted PreparedModel for the CPU, constructed by CpuDevice.
@@ -629,11 +587,10 @@ class CpuPreparedModel : public PreparedModel {
const std::vector<RunTimePoolInfo> mModelPoolInfos;
};
-void CpuDevice::getSupportedOperations(const MetaModel& metaModel,
- hidl_vec<bool>* supportedOperations) const {
+std::vector<bool> CpuDevice::getSupportedOperations(const MetaModel& metaModel) const {
const Model& hidlModel = metaModel.getModel();
const size_t count = hidlModel.operations.size();
- hidl_vec<bool> result(count);
+ std::vector<bool> result(count, false);
for (size_t i = 0; i < count; i++) {
// TODO(b/119870033): Decide whether and how post-P operations would be supported on CPU.
// We may want to use the slicer for CpuDevice just as we do for
@@ -642,7 +599,7 @@ void CpuDevice::getSupportedOperations(const MetaModel& metaModel,
result[i] = !isExtensionOperationType(operationType) &&
operationType != OperationType::OEM_OPERATION;
}
- *supportedOperations = std::move(result);
+ return result;
}
std::pair<int, std::shared_ptr<PreparedModel>> CpuDevice::prepareModel(
@@ -751,42 +708,33 @@ std::shared_ptr<Device> DeviceManager::getCpuDevice() {
std::shared_ptr<Device> DeviceManager::forTest_makeDriverDevice(const std::string& name,
const sp<V1_0::IDevice>& device) {
- auto driverDevice = std::make_shared<DriverDevice>(name, device);
- CHECK(driverDevice->initialize());
+ const auto driverDevice = DriverDevice::create(name, device);
+ CHECK(driverDevice != nullptr);
return driverDevice;
}
void DeviceManager::findAvailableDevices() {
- using ::android::hidl::manager::V1_2::IServiceManager;
VLOG(MANAGER) << "findAvailableDevices";
- sp<IServiceManager> manager = hardware::defaultServiceManager1_2();
- if (manager == nullptr) {
- LOG(ERROR) << "Unable to open defaultServiceManager";
- return;
+ // register driver devices
+ const auto names = hardware::getAllHalInstanceNames(V1_0::IDevice::descriptor);
+ for (const auto& name : names) {
+ VLOG(MANAGER) << "Found interface " << name;
+ sp<V1_0::IDevice> device = V1_0::IDevice::getService(name);
+ if (device == nullptr) {
+ LOG(ERROR) << "Got a null IDEVICE for " << name;
+ continue;
+ }
+ registerDevice(name, device);
}
- manager->listManifestByInterface(
- V1_0::IDevice::descriptor, [this](const hidl_vec<hidl_string>& names) {
- for (const auto& name : names) {
- VLOG(MANAGER) << "Found interface " << name.c_str();
- sp<V1_0::IDevice> device = V1_0::IDevice::getService(name);
- if (device == nullptr) {
- LOG(ERROR) << "Got a null IDEVICE for " << name.c_str();
- continue;
- }
- registerDevice(name.c_str(), device);
- }
- });
-
// register CPU fallback device
mDevices.push_back(CpuDevice::get());
mDevicesCpuOnly.push_back(CpuDevice::get());
}
-void DeviceManager::registerDevice(const char* name, const sp<V1_0::IDevice>& device) {
- auto d = std::make_shared<DriverDevice>(name, device);
- if (d->initialize()) {
+void DeviceManager::registerDevice(const std::string& name, const sp<V1_0::IDevice>& device) {
+ if (const auto d = DriverDevice::create(name, device)) {
mDevices.push_back(d);
}
}
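
DriverDevice construction now goes through a static create() factory that returns nullptr when the underlying VersionedIDevice cannot be built, so DeviceManager::registerDevice simply skips drivers that fail initialization instead of holding a half-initialized object. A small sketch of the same create-or-null pattern, using a stand-in FakeDriverDevice class rather than the real one:

    #include <iostream>
    #include <memory>
    #include <string>

    // Creation can fail, so it is funneled through a factory that returns
    // nullptr instead of exposing a partially constructed device.
    class FakeDriverDevice {
      public:
        static std::shared_ptr<FakeDriverDevice> create(std::string name) {
            if (name.empty()) return nullptr;  // stand-in for a failed initialization
            return std::make_shared<FakeDriverDevice>(std::move(name));
        }
        explicit FakeDriverDevice(std::string name) : mName(std::move(name)) {}
        const std::string& getName() const { return mName; }

      private:
        std::string mName;
    };

    int main() {
        // Mirrors registerDevice: only keep devices that were created cleanly.
        if (const auto d = FakeDriverDevice::create("sample-driver")) {
            std::cout << "registered " << d->getName() << "\n";
        }
    }
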
diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h
index 82042d108..cece348ee 100644
--- a/nn/runtime/Manager.h
+++ b/nn/runtime/Manager.h
@@ -68,15 +68,14 @@ class Device {
virtual ~Device() = default;
// Introspection methods returning device information
- virtual const char* getName() const = 0;
- virtual const char* getVersionString() const = 0;
+ virtual const std::string& getName() const = 0;
+ virtual const std::string& getVersionString() const = 0;
virtual int64_t getFeatureLevel() const = 0;
virtual int32_t getType() const = 0;
- virtual hal::hidl_vec<hal::Extension> getSupportedExtensions() const = 0;
+ virtual const std::vector<hal::Extension>& getSupportedExtensions() const = 0;
// See the MetaModel class in MetaModel.h for more details.
- virtual void getSupportedOperations(const MetaModel& metaModel,
- hal::hidl_vec<bool>* supportedOperations) const = 0;
+ virtual std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const = 0;
virtual hal::PerformanceInfo getPerformance(hal::OperandType type) const = 0;
virtual hal::PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const = 0;
@@ -142,7 +141,7 @@ class DeviceManager {
}
// Register a test device.
- void forTest_registerDevice(const char* name, const sp<hal::V1_0::IDevice>& device) {
+ void forTest_registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device) {
registerDevice(name, device);
}
@@ -166,7 +165,7 @@ class DeviceManager {
DeviceManager();
// Adds a device for the manager to use.
- void registerDevice(const char* name, const sp<hal::V1_0::IDevice>& device);
+ void registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device);
void findAvailableDevices();
diff --git a/nn/runtime/NeuralNetworks.cpp b/nn/runtime/NeuralNetworks.cpp
index 17667fd50..3542c24b1 100644
--- a/nn/runtime/NeuralNetworks.cpp
+++ b/nn/runtime/NeuralNetworks.cpp
@@ -582,7 +582,7 @@ int ANeuralNetworksDevice_getName(const ANeuralNetworksDevice* device, const cha
return ANEURALNETWORKS_UNEXPECTED_NULL;
}
const Device* d = reinterpret_cast<const Device*>(device);
- *name = d->getName();
+ *name = d->getName().c_str();
return ANEURALNETWORKS_NO_ERROR;
}
@@ -592,7 +592,7 @@ int ANeuralNetworksDevice_getVersion(const ANeuralNetworksDevice* device, const
return ANEURALNETWORKS_UNEXPECTED_NULL;
}
const Device* d = reinterpret_cast<const Device*>(device);
- *version = d->getVersionString();
+ *version = d->getVersionString().c_str();
return ANEURALNETWORKS_NO_ERROR;
}
@@ -665,8 +665,7 @@ int ANeuralNetworksModel_getSupportedOperationsForDevices(
Device* d = reinterpret_cast<Device*>(const_cast<ANeuralNetworksDevice*>(devices[i]));
const MetaModel metaModel(hidlModel, DeviceManager::get()->strictSlicing());
- hidl_vec<bool> supportsByDevice;
- d->getSupportedOperations(metaModel, &supportsByDevice);
+ const std::vector<bool> supportsByDevice = d->getSupportedOperations(metaModel);
for (uint32_t j = 0; j < supportsByDevice.size(); j++) {
uint32_t originalIdx = opMap[j];
supportedOps[originalIdx] |= supportsByDevice[j];
@@ -1186,16 +1185,12 @@ int ANeuralNetworksDevice_getExtensionSupport(const ANeuralNetworksDevice* devic
return ANEURALNETWORKS_UNEXPECTED_NULL;
}
- Device* d = reinterpret_cast<Device*>(const_cast<ANeuralNetworksDevice*>(device));
- hidl_vec<Extension> supportedExtensions = d->getSupportedExtensions();
-
- *isExtensionSupported = false;
- for (const Extension& supportedExtension : supportedExtensions) {
- if (supportedExtension.name == extensionName) {
- *isExtensionSupported = true;
- break;
- }
- }
+ const Device* d = reinterpret_cast<const Device*>(device);
+ const auto& supportedExtensions = d->getSupportedExtensions();
+ *isExtensionSupported = std::any_of(supportedExtensions.begin(), supportedExtensions.end(),
+ [extensionName](const auto& supportedExtension) {
+ return supportedExtension.name == extensionName;
+ });
return ANEURALNETWORKS_NO_ERROR;
}
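
The extension-support query above becomes a single std::any_of over the cached extension list instead of a hand-written loop with an output flag. A standalone usage sketch of that idiom, with a placeholder Ext struct in place of hal::Extension:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    struct Ext { std::string name; };

    int main() {
        const std::vector<Ext> supported = {{"vendor.test.foo"}, {"vendor.test.bar"}};
        const std::string wanted = "vendor.test.bar";
        const bool isSupported = std::any_of(supported.begin(), supported.end(),
                                             [&wanted](const Ext& e) { return e.name == wanted; });
        std::cout << std::boolalpha << isSupported << "\n";
    }
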
diff --git a/nn/runtime/TypeManager.cpp b/nn/runtime/TypeManager.cpp
index 854e28ab9..40b0e3d44 100644
--- a/nn/runtime/TypeManager.cpp
+++ b/nn/runtime/TypeManager.cpp
@@ -181,7 +181,7 @@ bool TypeManager::isExtensionsUseAllowed(const AppPackageInfo& appPackageInfo,
void TypeManager::findAvailableExtensions() {
for (const std::shared_ptr<Device>& device : mDeviceManager->getDrivers()) {
- for (const Extension extension : device->getSupportedExtensions()) {
+ for (const Extension& extension : device->getSupportedExtensions()) {
registerExtension(extension, device->getName());
}
}
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index ba6e2af7c..8da4e7c1f 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -306,21 +306,88 @@ std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExec
std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
sp<V1_0::IDevice> device) {
+ CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object.";
+
auto core = Core::create(std::move(device));
if (!core.has_value()) {
LOG(ERROR) << "VersionedIDevice::create failed to create Core.";
return nullptr;
}
- // return a valid VersionedIDevice object
- return std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value()));
+ // create and initialize a VersionedIDevice object
+ const auto versionedIDevice =
+ std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value()));
+ if (!versionedIDevice->initializeInternal()) {
+ LOG(ERROR) << "VersionedIDevice failed to initialize";
+ return nullptr;
+ }
+
+ // return a valid, initialized VersionedIDevice object
+ return versionedIDevice;
}
VersionedIDevice::VersionedIDevice(std::string serviceName, Core core)
: mServiceName(std::move(serviceName)), mCore(std::move(core)) {}
+bool VersionedIDevice::initializeInternal() {
+ auto [capabilitiesStatus, capabilities] = getCapabilitiesInternal();
+ if (capabilitiesStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "IDevice::getCapabilities* returned the error "
+ << toString(capabilitiesStatus);
+ return false;
+ }
+ VLOG(MANAGER) << "Capab " << toString(capabilities);
+
+ const auto [versionStatus, versionString] = getVersionStringInternal();
+ // TODO(miaowang): add a validation test case for in case of error.
+ if (versionStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(versionStatus);
+ return false;
+ }
+
+ const int32_t type = getTypeInternal();
+ if (type == -1) {
+ LOG(ERROR) << "IDevice::getType returned an error";
+ return false;
+ }
+
+ const auto [extensionsStatus, supportedExtensions] = getSupportedExtensionsInternal();
+ if (extensionsStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "IDevice::getSupportedExtensions returned the error "
+ << toString(extensionsStatus);
+ return false;
+ }
+
+ const auto [cacheFilesStatus, numModelCacheFiles, numDataCacheFiles] =
+ getNumberOfCacheFilesNeededInternal();
+ if (cacheFilesStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "IDevice::getNumberOfCacheFilesNeeded returned the error "
+ << toString(cacheFilesStatus);
+ return false;
+ }
+
+ // The following limit is enforced by VTS
+ constexpr uint32_t maxNumCacheFiles =
+ static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES);
+ if (numModelCacheFiles > maxNumCacheFiles || numDataCacheFiles > maxNumCacheFiles) {
+ LOG(ERROR)
+ << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files: "
+ "numModelCacheFiles = "
+ << numModelCacheFiles << ", numDataCacheFiles = " << numDataCacheFiles
+ << ", maxNumCacheFiles = " << maxNumCacheFiles;
+ return false;
+ }
+
+ // set internal members
+ mCapabilities = std::move(capabilities);
+ mSupportedExtensions = supportedExtensions;
+ mType = type;
+ mVersionString = versionString;
+ mNumberOfCacheFilesNeeded = {numModelCacheFiles, numDataCacheFiles};
+ return true;
+}
+
std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
- // verify input
CHECK(device != nullptr) << "VersionedIDevice::Core::create passed invalid device object.";
// create death handler object
@@ -480,7 +547,7 @@ Return<T_Return> VersionedIDevice::recoverable(
return ret;
}
-std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() const {
+std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilitiesInternal() const {
const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
std::pair<ErrorStatus, Capabilities> result;
@@ -559,7 +626,12 @@ std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() const {
return {ErrorStatus::DEVICE_UNAVAILABLE, {}};
}
-std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() const {
+const Capabilities& VersionedIDevice::getCapabilities() const {
+ return mCapabilities;
+}
+
+std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensionsInternal()
+ const {
const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
// version 1.2+ HAL
@@ -590,6 +662,10 @@ std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtens
return {ErrorStatus::DEVICE_UNAVAILABLE, {}};
}
+const std::vector<Extension>& VersionedIDevice::getSupportedExtensions() const {
+ return mSupportedExtensions;
+}
+
std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
const MetaModel& metaModel) const {
const std::pair<ErrorStatus, hidl_vec<bool>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
@@ -967,7 +1043,7 @@ int64_t VersionedIDevice::getFeatureLevel() const {
}
}
-int32_t VersionedIDevice::getType() const {
+int32_t VersionedIDevice::getTypeInternal() const {
constexpr int32_t kFailure = -1;
// version 1.2+ HAL
@@ -988,12 +1064,22 @@ int32_t VersionedIDevice::getType() const {
return result;
}
- // version too low or no device available
- LOG(INFO) << "Unknown NNAPI device type.";
- return ANEURALNETWORKS_DEVICE_UNKNOWN;
+ // version too low
+ if (getDevice<V1_0::IDevice>() != nullptr) {
+ LOG(INFO) << "Unknown NNAPI device type.";
+ return ANEURALNETWORKS_DEVICE_UNKNOWN;
+ }
+
+ // No device available
+ LOG(ERROR) << "Could not handle getType";
+ return kFailure;
+}
+
+int32_t VersionedIDevice::getType() const {
+ return mType;
}
-std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() const {
+std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionStringInternal() const {
const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""};
// version 1.2+ HAL
@@ -1023,7 +1109,12 @@ std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() const {
return kFailure;
}
-std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const {
+const std::string& VersionedIDevice::getVersionString() const {
+ return mVersionString;
+}
+
+std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeededInternal()
+ const {
constexpr std::tuple<ErrorStatus, uint32_t, uint32_t> kFailure = {ErrorStatus::GENERAL_FAILURE,
0, 0};
@@ -1055,5 +1146,13 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi
return kFailure;
}
+std::pair<uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const {
+ return mNumberOfCacheFilesNeeded;
+}
+
+const std::string& VersionedIDevice::getName() const {
+ return mServiceName;
+}
+
} // namespace nn
} // namespace android
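
VersionedIDevice now performs its HAL introspection queries once, in initializeInternal() at creation time, and the public getters return the cached results by value or const reference; a device that fails any query is rejected up front. A reduced sketch of that initialize-once pattern, with a hypothetical FakeVersionedDevice standing in for the HIDL-backed class:

    #include <iostream>
    #include <memory>
    #include <string>
    #include <utility>

    class FakeVersionedDevice {
      public:
        static std::shared_ptr<FakeVersionedDevice> create(std::string service) {
            auto device = std::make_shared<FakeVersionedDevice>(std::move(service));
            if (!device->initializeInternal()) return nullptr;  // reject on any failed query
            return device;
        }
        explicit FakeVersionedDevice(std::string service) : mServiceName(std::move(service)) {}

        // Getters are cheap: they return values cached at creation time.
        const std::string& getName() const { return mServiceName; }
        const std::string& getVersionString() const { return mVersionString; }
        std::pair<unsigned, unsigned> getNumberOfCacheFilesNeeded() const { return mNumCacheFiles; }

      private:
        bool initializeInternal() {
            mVersionString = "1.2";     // imagine IDevice::getVersionString over the wire
            mNumCacheFiles = {1u, 1u};  // imagine IDevice::getNumberOfCacheFilesNeeded
            return true;
        }
        std::string mServiceName;
        std::string mVersionString;
        std::pair<unsigned, unsigned> mNumCacheFiles{0u, 0u};
    };

    int main() {
        if (const auto d = FakeVersionedDevice::create("test-service")) {
            std::cout << d->getName() << " " << d->getVersionString() << " "
                      << d->getNumberOfCacheFilesNeeded().first << "\n";
        }
    }
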
diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h
index 87e776507..7f6f11af3 100644
--- a/nn/runtime/VersionedInterfaces.h
+++ b/nn/runtime/VersionedInterfaces.h
@@ -94,13 +94,9 @@ class VersionedIDevice {
/**
* Gets the capabilities of a driver.
*
- * @return status Error status of the call, must be:
- * - NONE if successful
- * - DEVICE_UNAVAILABLE if driver is offline or busy
- * - GENERAL_FAILURE if there is an unspecified error
* @return capabilities Capabilities of the driver.
*/
- std::pair<hal::ErrorStatus, hal::Capabilities> getCapabilities() const;
+ const hal::Capabilities& getCapabilities() const;
/**
* Gets information about extensions supported by the driver implementation.
@@ -111,13 +107,9 @@ class VersionedIDevice {
* All extension operations and operands must be fully supported for the
* extension to appear in the list of supported extensions.
*
- * @return status Error status of the call, must be:
- * - NONE if successful
- * - DEVICE_UNAVAILABLE if driver is offline or busy
- * - GENERAL_FAILURE if there is an unspecified error
* @return extensions A list of supported extensions.
*/
- std::pair<hal::ErrorStatus, hal::hidl_vec<hal::Extension>> getSupportedExtensions() const;
+ const std::vector<hal::Extension>& getSupportedExtensions() const;
/**
* Gets the supported operations in a MetaModel.
@@ -340,14 +332,12 @@ class VersionedIDevice {
/**
* Returns the device type of a driver.
*
- * @return deviceType The type of a given device, which can help application developers
- * developers to distribute Machine Learning workloads and other workloads
- * such as graphical rendering. E.g., for an app which renders AR scenes
- * based on real time object detection results, the developer could choose
- * an ACCELERATOR type device for ML workloads, and reserve GPU for
- * graphical rendering.
- * Return -1 if the driver is offline or busy, or the query resulted in
- * an unspecified error.
+ * @return deviceType The type of a given device, which can help application
+ * developers to distribute Machine Learning workloads and other
+ * workloads such as graphical rendering. E.g., for an app which renders
+ * AR scenes based on real time object detection results, the developer
+ * could choose an ACCELERATOR type device for ML workloads, and reserve
+ * GPU for graphical rendering.
*/
int32_t getType() const;
@@ -371,15 +361,9 @@ class VersionedIDevice {
* the driver cannot meet that requirement because of bugs or certain optimizations.
* The application can filter out versions of these drivers.
*
- * @return status Error status returned from querying the version string. Must be:
- * - NONE if the query was successful
- * - DEVICE_UNAVAILABLE if driver is offline or busy
- * - GENERAL_FAILURE if the query resulted in an
- * unspecified error
* @return version The version string of the device implementation.
- * Must have nonzero length if the query is successful, and must be an empty string if not.
*/
- std::pair<hal::ErrorStatus, hal::hidl_string> getVersionString() const;
+ const std::string& getVersionString() const;
/**
* Gets the caching requirements of the driver implementation.
@@ -408,10 +392,6 @@ class VersionedIDevice {
* IDevice::prepareModelFromCache or providing cache file descriptors to
* IDevice::prepareModel_1_2.
*
- * @return status Error status of the call, must be:
- * - NONE if successful
- * - DEVICE_UNAVAILABLE if driver is offline or busy
- * - GENERAL_FAILURE if there is an unspecified error
* @return numModelCache An unsigned integer indicating how many files for model cache
* the driver needs to cache a single prepared model. It must
* be less than or equal to Constant::MAX_NUMBER_OF_CACHE_FILES.
@@ -419,9 +399,35 @@ class VersionedIDevice {
* the driver needs to cache a single prepared model. It must
* be less than or equal to Constant::MAX_NUMBER_OF_CACHE_FILES.
*/
- std::tuple<hal::ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const;
+ std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const;
+
+ /**
+ * Returns the name of the service.
+ *
+ * @return Name of the service.
+ */
+ const std::string& getName() const;
private:
+ // initializeInternal is called once during VersionedIDevice creation.
+ // 'true' indicates successful initialization.
+ bool initializeInternal();
+
+ // internal helper methods
+ std::pair<hal::ErrorStatus, hal::Capabilities> getCapabilitiesInternal() const;
+ std::pair<hal::ErrorStatus, hal::hidl_vec<hal::Extension>> getSupportedExtensionsInternal()
+ const;
+ int32_t getTypeInternal() const;
+ std::pair<hal::ErrorStatus, hal::hidl_string> getVersionStringInternal() const;
+ std::tuple<hal::ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeededInternal() const;
+
+ // internal members for the cached results of the internal methods above
+ hal::Capabilities mCapabilities;
+ std::vector<hal::Extension> mSupportedExtensions;
+ int32_t mType;
+ std::string mVersionString;
+ std::pair<uint32_t, uint32_t> mNumberOfCacheFilesNeeded;
+
/**
* This is a utility class for VersionedIDevice that encapsulates a
* V1_0::IDevice, any appropriate downcasts to newer interfaces, and a
diff --git a/nn/runtime/test/TestCompilationCaching.cpp b/nn/runtime/test/TestCompilationCaching.cpp
index c061e853d..bce2836d2 100644
--- a/nn/runtime/test/TestCompilationCaching.cpp
+++ b/nn/runtime/test/TestCompilationCaching.cpp
@@ -14,12 +14,14 @@
* limitations under the License.
*/
+#include <android-base/scopeguard.h>
#include <gtest/gtest.h>
#include <cstdlib>
#include <filesystem>
#include <numeric>
#include <string>
+#include <string_view>
#include <tuple>
#include <vector>
@@ -48,13 +50,32 @@ namespace {
enum class HasCalledPrepareModel { NO, WITHOUT_CACHING, WITH_CACHING };
+// Print HasCalledPrepareModel enum for better GTEST failure messages
+std::ostream& operator<<(std::ostream& os, HasCalledPrepareModel hasCalledPrepareModel) {
+ switch (hasCalledPrepareModel) {
+ case HasCalledPrepareModel::NO:
+ return os << "NO";
+ case HasCalledPrepareModel::WITHOUT_CACHING:
+ return os << "WITHOUT_CACHING";
+ case HasCalledPrepareModel::WITH_CACHING:
+ return os << "WITH_CACHING";
+ }
+ CHECK(false) << "HasCalledPrepareModel print called with invalid code "
+ << static_cast<int>(hasCalledPrepareModel);
+ return os;
+}
+
+// Whether the driver is expected to be registered because it can pass initialization.
+bool canDeviceBeRegistered(ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
+ constexpr uint32_t maxNumCacheFiles =
+ static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES);
+ return error == ErrorStatus::NONE && numModelCache <= maxNumCacheFiles &&
+ numDataCache <= maxNumCacheFiles;
+}
+
// Whether the driver supports caching based on the returns from getNumberOfCacheFilesNeeded.
-bool isCachingSupportedAndNoError(ErrorStatus error, uint32_t numModelCache,
- uint32_t numDataCache) {
- return error == ErrorStatus::NONE &&
- numModelCache <= static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) &&
- numDataCache <= static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) &&
- (numModelCache != 0 || numDataCache != 0);
+bool isCachingSupported(uint32_t numModelCache, uint32_t numDataCache) {
+ return numModelCache != 0 || numDataCache != 0;
}
// This is an IDevice for testing purposes which overrides several methods from sample driver:
@@ -99,9 +120,10 @@ class CachingDriver : public sample_driver::SampleDriver {
};
public:
- CachingDriver(const char* name, ErrorStatus errorStatusGetNumCacheFiles, uint32_t numModelCache,
- uint32_t numDataCache, ErrorStatus errorStatusPrepareFromCache)
- : SampleDriver(name),
+ CachingDriver(std::string_view name, ErrorStatus errorStatusGetNumCacheFiles,
+ uint32_t numModelCache, uint32_t numDataCache,
+ ErrorStatus errorStatusPrepareFromCache)
+ : SampleDriver(name.data()),
mErrorStatusGetNumCacheFiles(errorStatusGetNumCacheFiles),
mNumModelCache(numModelCache),
mNumDataCache(numDataCache),
@@ -181,8 +203,7 @@ class CachingDriver : public sample_driver::SampleDriver {
private:
// Checks the number of cache files passed to the driver from runtime.
void checkNumberOfCacheHandles(size_t modelCache, size_t dataCache) {
- if (isCachingSupportedAndNoError(mErrorStatusGetNumCacheFiles, mNumModelCache,
- mNumDataCache)) {
+ if (isCachingSupported(mNumModelCache, mNumDataCache)) {
if (modelCache != 0 || dataCache != 0) {
ASSERT_EQ(modelCache, mNumModelCache);
ASSERT_EQ(dataCache, mNumDataCache);
@@ -239,12 +260,69 @@ void CreateBroadcastAddModel(test_wrapper::Model* model) {
ASSERT_EQ(model->finish(), Result::NO_ERROR);
}
-// Test model compilation with a driver parameterized with
+void getDeviceWithName(std::string_view deviceName, const ANeuralNetworksDevice** outputDevice) {
+ uint32_t numDevices = 0;
+ ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
+ EXPECT_GE(numDevices, (uint32_t)1);
+
+ int numMatchingDevices = 0;
+ for (uint32_t i = 0; i < numDevices; i++) {
+ ANeuralNetworksDevice* device = nullptr;
+ ASSERT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
+
+ const char* buffer = nullptr;
+ ASSERT_EQ(ANeuralNetworksDevice_getName(device, &buffer), ANEURALNETWORKS_NO_ERROR);
+ if (deviceName == buffer) {
+ *outputDevice = device;
+ numMatchingDevices++;
+ }
+ }
+
+ EXPECT_LE(numMatchingDevices, 1);
+}
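getDeviceWithName() reports its result through an out-parameter because gtest's ASSERT_* macros may only be used in functions returning void, so callers check the pointer themselves. A minimal, hypothetical call site:

    // Hypothetical call site, not part of the patch.
    const ANeuralNetworksDevice* device = nullptr;
    getDeviceWithName(kDeviceName, &device);
    ASSERT_NE(device, nullptr);  // compileModel() below performs exactly this check.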
+
+// Test device registration with a driver parameterized with
// - ErrorStatus returned by getNumberOfCacheFilesNeeded
// - Number of model cache files returned by getNumberOfCacheFilesNeeded
// - Number of data cache files returned by getNumberOfCacheFilesNeeded
+using DeviceRegistrationTestParam = std::tuple<ErrorStatus, uint32_t, uint32_t>;
+
+class DeviceRegistrationTest : public ::testing::TestWithParam<DeviceRegistrationTestParam> {
+ protected:
+ static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
+ const ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam());
+ const uint32_t kNumModelCache = std::get<1>(GetParam());
+ const uint32_t kNumDataCache = std::get<2>(GetParam());
+ const sp<CachingDriver> kDriver =
+ new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
+ kNumDataCache, ErrorStatus::NONE);
+};
+
+TEST_P(DeviceRegistrationTest, CachingFailure) {
+ if (DeviceManager::get()->getUseCpuOnly()) {
+ return;
+ }
+
+ DeviceManager::get()->forTest_registerDevice(kDeviceName.data(), kDriver);
+ const auto cleanup = android::base::make_scope_guard(
+ [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });
+
+    // look up the test device by name; it stays nullptr if registration failed
+ const ANeuralNetworksDevice* device = nullptr;
+ getDeviceWithName(kDeviceName, &device);
+
+    // check whether device registration matches expectations
+ const bool isDeviceRegistered = (device != nullptr);
+ const bool expectDeviceToBeRegistered =
+ canDeviceBeRegistered(kErrorStatusGetNumCacheFiles, kNumModelCache, kNumDataCache);
+ ASSERT_EQ(isDeviceRegistered, expectDeviceToBeRegistered);
+}
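The scope guard guarantees that the test-only device registration is undone even when a fatal assertion ends the test body early, because the guard's destructor runs whenever the test function returns. The register/cleanup pattern in isolation (sketch only):

    // Sketch of the pattern used above, not part of the patch.
    DeviceManager::get()->forTest_registerDevice(kDeviceName.data(), kDriver);
    const auto cleanup = android::base::make_scope_guard(
            [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });
    // ... assertions; on any return, cleanup restores the original device list.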
+
+// Test model compilation with a driver parameterized with
+// - Number of model cache files returned by getNumberOfCacheFilesNeeded
+// - Number of data cache files returned by getNumberOfCacheFilesNeeded
// - ErrorStatus returned by prepareModelFromCache
-using CompilationCachingTestParam = std::tuple<ErrorStatus, uint32_t, uint32_t, ErrorStatus>;
+using CompilationCachingTestParam = std::tuple<uint32_t, uint32_t, ErrorStatus>;
class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachingTestParam> {
protected:
@@ -254,9 +332,6 @@ class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachin
ASSERT_NE(cacheDir, nullptr);
mCacheDir = cacheDir;
CreateBroadcastAddModel(&mModel);
- mToken = std::vector<uint8_t>(ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN, 0);
- mIsCachingSupportedAndNoError = isCachingSupportedAndNoError(kErrorStatusGetNumCacheFiles,
- kNumModelCache, kNumDataCache);
}
virtual void TearDown() override {
@@ -266,38 +341,26 @@ class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachin
}
void compileModel(const sp<CachingDriver>& driver, bool withToken) {
- DeviceManager::get()->forTest_registerDevice(kDeviceName, driver);
-
- // Make device list including only a single driver device.
- uint32_t numDevices = 0;
- EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
- EXPECT_GE(numDevices, (uint32_t)1);
- std::vector<ANeuralNetworksDevice*> devices;
- for (uint32_t i = 0; i < numDevices; i++) {
- ANeuralNetworksDevice* device = nullptr;
- EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
- const char* buffer = nullptr;
- int result = ANeuralNetworksDevice_getName(device, &buffer);
- if (result == ANEURALNETWORKS_NO_ERROR && strcmp(buffer, kDeviceName) == 0) {
- devices.push_back(device);
- break;
- }
- }
- ASSERT_EQ(devices.size(), 1u);
+ DeviceManager::get()->forTest_registerDevice(kDeviceName.data(), driver);
+ const auto cleanup = android::base::make_scope_guard(
+ [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });
+
+ // Get a handle to the single driver device matching kDeviceName.
+ const ANeuralNetworksDevice* device = nullptr;
+ getDeviceWithName(kDeviceName, &device);
+ ASSERT_NE(device, nullptr);
// Compile the model with the device.
ANeuralNetworksCompilation* compilation = nullptr;
- ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), devices.data(),
- devices.size(), &compilation),
+ ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), &device, 1,
+ &compilation),
ANEURALNETWORKS_NO_ERROR);
if (withToken) {
ASSERT_EQ(ANeuralNetworksCompilation_setCaching(compilation, mCacheDir.c_str(),
- mToken.data()),
+ kToken.data()),
ANEURALNETWORKS_NO_ERROR);
}
ASSERT_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR);
-
- DeviceManager::get()->forTest_reInitializeDeviceList();
}
void createCache() {
@@ -306,31 +369,29 @@ class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachin
compileModel(driver, /*withToken=*/true);
}
- static constexpr char kDeviceName[] = "deviceTestCompilationCaching";
- const ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam());
- const uint32_t kNumModelCache = std::get<1>(GetParam());
- const uint32_t kNumDataCache = std::get<2>(GetParam());
- const ErrorStatus kErrorStatusPrepareFromCache = std::get<3>(GetParam());
- bool mIsCachingSupportedAndNoError;
+ static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
+ const uint32_t kNumModelCache = std::get<0>(GetParam());
+ const uint32_t kNumDataCache = std::get<1>(GetParam());
+ const ErrorStatus kErrorStatusPrepareFromCache = std::get<2>(GetParam());
+ const bool kIsCachingSupported = isCachingSupported(kNumModelCache, kNumDataCache);
test_wrapper::Model mModel;
std::string mCacheDir;
- std::vector<uint8_t> mToken;
+ const CacheToken kToken{};
};
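The vector-based token is replaced by a value-initialized member. Assuming CacheToken is the fixed-size HAL token type, kToken{} yields an all-zero token equivalent to the removed vector of ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN zero bytes, and kToken.data() in compileModel() still returns a pointer to those bytes. Roughly:

    // Sketch of the equivalence, not part of the patch.
    std::vector<uint8_t> oldToken(ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN, 0);  // removed
    const CacheToken newToken{};  // value-initialized: every byte is zero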
TEST_P(CompilationCachingTest, TokenProvidedAndCacheNotExist) {
if (DeviceManager::get()->getUseCpuOnly()) {
return;
}
- sp<CachingDriver> driver =
- new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
- kNumDataCache, kErrorStatusPrepareFromCache);
+ sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
+ kNumDataCache, kErrorStatusPrepareFromCache);
compileModel(driver, /*withToken=*/true);
// When cache file does not exist, the runtime should never call prepareModelFromCache.
- EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), false);
+ EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());
// The runtime should call prepareModel_1_2. It should request caching iff caching is supported.
- EXPECT_EQ(driver->hasCalledPrepareModel(), mIsCachingSupportedAndNoError
+ EXPECT_EQ(driver->hasCalledPrepareModel(), kIsCachingSupported
? HasCalledPrepareModel::WITH_CACHING
: HasCalledPrepareModel::WITHOUT_CACHING);
}
@@ -340,16 +401,15 @@ TEST_P(CompilationCachingTest, TokenProvidedAndCacheExist) {
return;
}
createCache();
- sp<CachingDriver> driver =
- new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
- kNumDataCache, kErrorStatusPrepareFromCache);
+ sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
+ kNumDataCache, kErrorStatusPrepareFromCache);
compileModel(driver, /*withToken=*/true);
// When cache files exist, the runtime should call prepareModelFromCache iff caching is supported.
- EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), mIsCachingSupportedAndNoError);
+ EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), kIsCachingSupported);
HasCalledPrepareModel expectHasCalledPrepareModel;
- if (mIsCachingSupportedAndNoError) {
+ if (kIsCachingSupported) {
if (kErrorStatusPrepareFromCache == ErrorStatus::NONE) {
// The runtime should not call prepareModel_1_2 iff caching supported and
// prepareModelFromCache succeeds.
@@ -370,14 +430,13 @@ TEST_P(CompilationCachingTest, TokenNotProvided) {
if (DeviceManager::get()->getUseCpuOnly()) {
return;
}
- sp<CachingDriver> driver =
- new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
- kNumDataCache, kErrorStatusPrepareFromCache);
+ sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
+ kNumDataCache, kErrorStatusPrepareFromCache);
compileModel(driver, /*withToken=*/false);
// When no NDK token is provided by the client, the runtime should never call
// prepareModelFromCache or request caching with prepareModel_1_2.
- EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), false);
+ EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());
EXPECT_EQ(driver->hasCalledPrepareModel(), HasCalledPrepareModel::WITHOUT_CACHING);
}
@@ -386,12 +445,18 @@ static const auto kErrorStatusGetNumCacheFilesChoices =
static const auto kNumCacheChoices =
testing::Values(0ul, 1ul, static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES),
static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) + 1);
+static const auto kNumValidCacheChoices =
+ testing::Values(0ul, 1ul, static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES));
static const auto kErrorStatusPrepareFromCacheChoices =
testing::Values(ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE,
ErrorStatus::DEVICE_UNAVAILABLE, ErrorStatus::INVALID_ARGUMENT);
-INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest,
+INSTANTIATE_TEST_CASE_P(TestCompilationCaching, DeviceRegistrationTest,
testing::Combine(kErrorStatusGetNumCacheFilesChoices, kNumCacheChoices,
- kNumCacheChoices, kErrorStatusPrepareFromCacheChoices));
+ kNumCacheChoices));
+
+INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest,
+ testing::Combine(kNumValidCacheChoices, kNumValidCacheChoices,
+ kErrorStatusPrepareFromCacheChoices));
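Splitting the parameters also shrinks the compilation matrix: CompilationCachingTest now covers 3 x 3 x 4 = 36 combinations (kNumValidCacheChoices for each cache count, times the four prepareModelFromCache statuses), while the out-of-range cache counts and the getNumberOfCacheFilesNeeded error statuses are exercised only by DeviceRegistrationTest, which pairs each error-status choice with the 4 x 4 grid of kNumCacheChoices.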
} // namespace
diff --git a/nn/runtime/test/TestIntrospectionControl.cpp b/nn/runtime/test/TestIntrospectionControl.cpp
index 9d0cbe6c3..76a6f5ee2 100644
--- a/nn/runtime/test/TestIntrospectionControl.cpp
+++ b/nn/runtime/test/TestIntrospectionControl.cpp
@@ -229,9 +229,9 @@ TEST_F(IntrospectionControlTest, SimpleAddModel) {
// Verify that the mCompilation is actually using the "test-all" device.
CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(mCompilation);
- const char* deviceNameBuffer =
+ const std::string& deviceNameBuffer =
c->forTest_getExecutionPlan().forTest_simpleGetDevice()->getName();
- EXPECT_TRUE(driverName.compare(deviceNameBuffer) == 0);
+ EXPECT_EQ(driverName, deviceNameBuffer);
float input1[2] = {1.0f, 2.0f};
float input2[2] = {3.0f, 4.0f};
@@ -655,7 +655,7 @@ TEST_P(TimingTest, Test) {
switch (kDriverKind) {
case DriverKind::CPU: {
// There should be only one driver -- the CPU
- const char* name = DeviceManager::get()->getDrivers()[0]->getName();
+ const std::string& name = DeviceManager::get()->getDrivers()[0]->getName();
ASSERT_TRUE(selectDeviceByName(name));
break;
}
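The comparison updates in this file and in TestPartitioning.cpp below follow from getName() now returning a const std::string& rather than a const char*: ASSERT_STREQ and strcmp expect C strings, whereas ASSERT_EQ and operator== compare std::string contents directly. A minimal sketch, assuming the new signature (plan and deviceName as in the tests below):

    // Illustrative only, assuming getName() returns const std::string&.
    const std::string& name = plan.forTest_simpleGetDevice()->getName();
    ASSERT_EQ(name, "good");               // compares string contents
    if (name == deviceName) { /* ... */ }  // replaces strcmp(name, deviceName) == 0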
diff --git a/nn/runtime/test/TestPartitioning.cpp b/nn/runtime/test/TestPartitioning.cpp
index b01e58015..2c774a6e8 100644
--- a/nn/runtime/test/TestPartitioning.cpp
+++ b/nn/runtime/test/TestPartitioning.cpp
@@ -1254,7 +1254,7 @@ TEST_F(PartitioningTest, SimpleModel) {
ANEURALNETWORKS_NO_ERROR);
ASSERT_EQ(planA.forTest_getKind(), ExecutionPlan::Kind::SIMPLE);
ASSERT_NE(planA.forTest_simpleGetDevice().get(), nullptr);
- ASSERT_STREQ(planA.forTest_simpleGetDevice()->getName(), "good");
+ ASSERT_EQ(planA.forTest_simpleGetDevice()->getName(), "good");
// Simple partition (two devices are each capable of everything, none better than CPU).
// No need to compare the original model to the model from the plan -- we
@@ -1342,7 +1342,7 @@ TEST_F(PartitioningTest, SliceModel) {
ANEURALNETWORKS_NO_ERROR);
ASSERT_EQ(planA.forTest_getKind(), ExecutionPlan::Kind::SIMPLE);
ASSERT_NE(planA.forTest_simpleGetDevice().get(), nullptr);
- ASSERT_STREQ(planA.forTest_simpleGetDevice()->getName(), "V1_2");
+ ASSERT_EQ(planA.forTest_simpleGetDevice()->getName(), "V1_2");
// Compound partition (V1_0, V1_1, V1_2 devices are available, in decreasing
// order of performance; model is distributed across all three devices).
@@ -1442,7 +1442,7 @@ TEST_F(PartitioningTest, SliceModelToEmpty) {
ANEURALNETWORKS_NO_ERROR);
ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE);
ASSERT_NE(plan.forTest_simpleGetDevice().get(), nullptr);
- ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), "V1_2");
+ ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), "V1_2");
}
TEST_F(PartitioningTest, Cpu) {
@@ -1682,7 +1682,7 @@ TEST_F(PartitioningTest, OemOperations) {
const auto& planBestOEM = compilationBestOEM.getExecutionPlan();
ASSERT_EQ(planBestOEM.forTest_getKind(), ExecutionPlan::Kind::SIMPLE);
ASSERT_NE(planBestOEM.forTest_simpleGetDevice().get(), nullptr);
- ASSERT_STREQ(planBestOEM.forTest_simpleGetDevice()->getName(), "goodOEM");
+ ASSERT_EQ(planBestOEM.forTest_simpleGetDevice()->getName(), "goodOEM");
// Verify that we get an error if no driver can run an OEM operation.
const auto devicesNoOEM = makeDevices({{"noOEM", 0.5, ~0U, PartitioningDriver::OEMNo}});
@@ -1724,7 +1724,7 @@ TEST_F(PartitioningTest, RelaxedFP) {
ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan),
ANEURALNETWORKS_NO_ERROR);
ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE);
- ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), expectDevice);
+ ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), expectDevice);
};
ASSERT_NO_FATAL_FAILURE(TrivialTest(false, "f32"));
@@ -1772,7 +1772,7 @@ TEST_F(PartitioningTest, Perf) {
ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan),
ANEURALNETWORKS_NO_ERROR);
ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE);
- ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), "good");
+ ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), "good");
}
{
@@ -1790,7 +1790,7 @@ TEST_F(PartitioningTest, Perf) {
ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan),
ANEURALNETWORKS_NO_ERROR);
ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE);
- ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), "base");
+ ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), "base");
}
};
@@ -1853,7 +1853,7 @@ class CacheTest : public PartitioningTest {
// Find the cache info for the device.
const uint8_t* token = nullptr;
if (plan.forTest_getKind() == ExecutionPlan::Kind::SIMPLE) {
- ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), deviceName);
+ ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), deviceName);
token = plan.forTest_simpleGetCacheToken();
} else if (plan.forTest_getKind() == ExecutionPlan::Kind::COMPOUND) {
const auto& steps = plan.forTest_compoundGetSteps();
@@ -1861,7 +1861,7 @@ class CacheTest : public PartitioningTest {
for (const auto& step : steps) {
// In general, two or more partitions can be on the same device. However, this will
// not happen on the test models with only 2 operations.
- if (strcmp(step->getDevice()->getName(), deviceName) == 0) {
+ if (step->getDevice()->getName() == deviceName) {
ASSERT_FALSE(found);
token = step->forTest_getCacheToken();
found = true;