diff options
author | Michael Butler <butlermichael@google.com> | 2019-10-29 11:51:41 -0700 |
---|---|---|
committer | android-build-merger <android-build-merger@google.com> | 2019-10-29 11:51:41 -0700 |
commit | b34fce3c2a57b2e5148900f9699340ec63f671b6 (patch) | |
tree | 27d70e863b09ccb2471a08e63c835d2777c5f3a6 | |
parent | bd86742ff4244cd721652c4b52ee5a43426cac82 (diff) | |
parent | 9b84d74900feaa7cd32bd8dff90724f1dee8120b (diff) | |
download | ml-b34fce3c2a57b2e5148900f9699340ec63f671b6.tar.gz |
Merge changes I032a5013,Ib1a95b75,I40584542,I026bad69,Icf98281b, ... am: 9c1d0b0442 am: cb060fcfee
am: 9b84d74900
Change-Id: I323d0920ec580027f71f7363da3da9f1e98c1f73
-rw-r--r-- | nn/runtime/BurstBuilder.h | 8 | ||||
-rw-r--r-- | nn/runtime/ExecutionBuilder.cpp | 208 | ||||
-rw-r--r-- | nn/runtime/ExecutionBuilder.h | 10 | ||||
-rw-r--r-- | nn/runtime/ExecutionPlan.cpp | 9 | ||||
-rw-r--r-- | nn/runtime/Manager.cpp | 216 | ||||
-rw-r--r-- | nn/runtime/Manager.h | 13 | ||||
-rw-r--r-- | nn/runtime/NeuralNetworks.cpp | 23 | ||||
-rw-r--r-- | nn/runtime/TypeManager.cpp | 2 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.cpp | 121 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.h | 66 | ||||
-rw-r--r-- | nn/runtime/test/TestCompilationCaching.cpp | 189 | ||||
-rw-r--r-- | nn/runtime/test/TestIntrospectionControl.cpp | 6 | ||||
-rw-r--r-- | nn/runtime/test/TestPartitioning.cpp | 18 |
13 files changed, 477 insertions, 412 deletions
diff --git a/nn/runtime/BurstBuilder.h b/nn/runtime/BurstBuilder.h index bfb9a9a38..6a3ba783e 100644 --- a/nn/runtime/BurstBuilder.h +++ b/nn/runtime/BurstBuilder.h @@ -31,10 +31,10 @@ class CompilationBuilder; * TODO: Could we "hide" the per-step burst controller instance inside * StepExecutor? Today it's exposed as a "sibling" to StepExecutor: * ExecutionPlan::next both generates a StepExecutor instance and finds a - * pointer to a burst controller; and StepExecutor::startCompute is passed a - * pointer to a burst controller. Instead, could ExecutionPlan::next stash the - * burst controller in the StepExecutor, so that it doesn't have to be passed - * to any of the StepExecutor methods? + * pointer to a burst controller; and StepExecutor::compute is passed a pointer + * to a burst controller. Instead, could ExecutionPlan::next stash the burst + * controller in the StepExecutor, so that it doesn't have to be passed to any + * of the StepExecutor methods? */ class BurstBuilder { diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp index 97e847b24..9e73118e9 100644 --- a/nn/runtime/ExecutionBuilder.cpp +++ b/nn/runtime/ExecutionBuilder.cpp @@ -303,12 +303,9 @@ int ExecutionBuilder::getOutputOperandRank(uint32_t index, uint32_t* rank) { // For Q this is irrelevant: We only support timing in conjunction // with an explicit device list; and we do not support CPU fallback // with an explicit device list. See CompilationBuilder::mExplicitDeviceList. -static int cpuFallbackFull(ExecutionBuilder* executionBuilder, - sp<ExecutionCallback>* fallbackCallback) { +static std::tuple<int, std::vector<OutputShape>, Timing> cpuFallbackFull( + ExecutionBuilder* executionBuilder) { CHECK(executionBuilder != nullptr); - CHECK(fallbackCallback != nullptr); - *fallbackCallback = nullptr; - NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "cpuFallbackFull"); VLOG(EXECUTION) << "cpuFallbackFull"; @@ -318,58 +315,45 @@ static int cpuFallbackFull(ExecutionBuilder* executionBuilder, executor.mapInputsAndOutputsTrivially(); // Attempt fallback execution. - NN_RETURN_IF_ERROR(executor.startComputeOnCpuFallback(fallbackCallback)); - CHECK(*fallbackCallback != nullptr); - (*fallbackCallback)->wait(); - return ANEURALNETWORKS_NO_ERROR; + return executor.computeOnCpuFallback(); } // Attempt synchronous execution on CPU. -// fallbackExecutor is non-null i.f.f. ANEURALNETWORKS_NO_ERROR is returned. -// fallbackCallback is non-null i.f.f. ANEURALNETWORKS_NO_ERROR is returned. // TODO: How should we handle timing in this case? // For Q this is irrelevant: We only support timing in conjunction // with an explicit device list; and we do not support CPU fallback // with an explicit device list. See CompilationBuilder::mExplicitDeviceList. -static int cpuFallbackPartial(const ExecutionPlan* plan, - std::shared_ptr<ExecutionPlan::Controller> controller, - std::shared_ptr<StepExecutor>* fallbackExecutor, - sp<ExecutionCallback>* fallbackCallback) { - CHECK(plan != nullptr); - CHECK(fallbackExecutor != nullptr); - *fallbackExecutor = nullptr; - CHECK(fallbackCallback != nullptr); - *fallbackCallback = nullptr; - +static std::tuple<int, std::vector<OutputShape>, Timing, std::shared_ptr<StepExecutor>> +cpuFallbackPartial(const ExecutionPlan& plan, + std::shared_ptr<ExecutionPlan::Controller> controller) { NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "cpuFallbackPartial"); VLOG(EXECUTION) << "cpuFallbackPartial"; // Get fallback executor. std::shared_ptr<StepExecutor> executor; - NN_RETURN_IF_ERROR(plan->fallback(controller, &executor)); + int n1 = plan.fallback(controller, &executor); + if (n1 != ANEURALNETWORKS_NO_ERROR) { + return {n1, {}, kNoTiming, nullptr}; + } CHECK(executor != nullptr); // Attempt fallback execution. - NN_RETURN_IF_ERROR(executor->startComputeOnCpuFallback(fallbackCallback)); - CHECK(*fallbackCallback != nullptr); - (*fallbackCallback)->wait(); - *fallbackExecutor = executor; - return ANEURALNETWORKS_NO_ERROR; + auto [n2, outputShapes, timing] = executor->computeOnCpuFallback(); + return {n2, std::move(outputShapes), timing, executor}; } static void asyncStartComputePartitioned(ExecutionBuilder* executionBuilder, - const ExecutionPlan* plan, + const ExecutionPlan& plan, std::shared_ptr<ExecutionPlan::Controller> controller, bool allowFallback, const sp<ExecutionCallback>& executionCallback) { CHECK(executionBuilder != nullptr); - CHECK(plan != nullptr); VLOG(EXECUTION) << "ExecutionBuilder::compute (from plan, iteratively)"; - std::vector<OutputShape> outputShapes; + + std::vector<OutputShape> outputShapes = executionBuilder->getInitialOutputShapes(); Timing timing = kNoTiming; // Disallow fallback when the ExecutionPlan is simple on CPU. - allowFallback &= !plan->isSimpleCpu(); - executionBuilder->initializeOutputShapes(&outputShapes); + allowFallback &= !plan.isSimpleCpu(); while (true) { VLOG(EXECUTION) << "looking for next StepExecutor"; @@ -377,7 +361,7 @@ static void asyncStartComputePartitioned(ExecutionBuilder* executionBuilder, // Get the current step of the execution. std::shared_ptr<StepExecutor> executor; std::shared_ptr<ExecutionBurstController> burstController; - int n = plan->next(controller, &executor, &burstController); + int n = plan.next(controller, &executor, &burstController); if (n != ANEURALNETWORKS_NO_ERROR) { if (allowFallback) break; executionCallback->notify(convertResultCodeToErrorStatus(n), {}, kNoTiming); @@ -393,99 +377,75 @@ static void asyncStartComputePartitioned(ExecutionBuilder* executionBuilder, const bool executorIsCpu = executor->isCpu(); // Attempt to execute a single step of the execution. - sp<ExecutionCallback> stepCallback; - n = executor->startCompute(&stepCallback, burstController); + auto [stepN, stepOutputShapes, stepTiming] = executor->compute(burstController); - // Immediately end execution if there was an error and fallback is not - // allowed. - if (n != ANEURALNETWORKS_NO_ERROR && !allowFallback) { - executionCallback->notify(convertResultCodeToErrorStatus(n), {}, kNoTiming); - return; + // Update global outputs. + if (!executor->updateOutputShapes(stepOutputShapes, &outputShapes)) { + stepN = ANEURALNETWORKS_OP_FAILED; } - // If execution successfully launched, process the execution. - if (n == ANEURALNETWORKS_NO_ERROR) { - stepCallback->wait(); - ErrorStatus status = stepCallback->getStatus(); - const auto& stepOutputShapes = stepCallback->getOutputShapes(); - - // Update global outputs. - if (!executor->updateOutputShapes(stepOutputShapes, &outputShapes)) { - status = ErrorStatus::GENERAL_FAILURE; - } - - // If execution was successful, continue to next step. - if (status == ErrorStatus::NONE) { - // We only support collection of timing information in the case of a - // single step, so it's safe to just keep track of the last step's - // timing information. - timing = stepCallback->getTiming(); - continue; - } - - // OUTPUT_INSUFFICIENT_SIZE is not recoverable, so end execution. - if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) { - executionCallback->notify(status, outputShapes, kNoTiming); - return; - } + // If execution was successful, continue to next step. + if (stepN == ANEURALNETWORKS_NO_ERROR) { + // We only support collection of timing information in the case of a + // single step, so it's safe to just keep track of the last step's + // timing information. + timing = stepTiming; + continue; + } - // If fallback is not allowed and there was an error, end execution. - if (!allowFallback) { - executionCallback->notify(status, {}, kNoTiming); - return; - } + // OUTPUT_INSUFFICIENT_SIZE is not recoverable, so end execution. + if (stepN == ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE) { + const ErrorStatus stepStatus = convertResultCodeToErrorStatus(stepN); + executionCallback->notify(stepStatus, outputShapes, kNoTiming); + return; + } - // Propagate error to fallback path. - n = convertErrorStatusToResultCode(status); + // If fallback is not allowed and there was an error, end execution. + if (!allowFallback) { + const ErrorStatus stepStatus = convertResultCodeToErrorStatus(stepN); + executionCallback->notify(stepStatus, {}, kNoTiming); + return; } // If CPU execution was already attempted, either: // (1) perform a full fallback if the plan is not simple, or // (2) return from the function with an error if (executorIsCpu) { - if (!plan->isSimple()) break; - executionCallback->notify(convertResultCodeToErrorStatus(n), {}, kNoTiming); + if (!plan.isSimple()) break; + executionCallback->notify(convertResultCodeToErrorStatus(stepN), {}, kNoTiming); return; } // If the code reaches this point, attempt a partial fallback to CPU. CHECK(allowFallback); - std::shared_ptr<StepExecutor> fallbackExecutor; - sp<ExecutionCallback> fallbackCallback; - n = cpuFallbackPartial(plan, controller, &fallbackExecutor, &fallbackCallback); - - // Immediately fallback to full CPU execution if there was an error with - // the partial CPU fallback. - if (n != ANEURALNETWORKS_NO_ERROR) { - break; - } - - // Get fallback execution results. - ErrorStatus fallbackStatus = fallbackCallback->getStatus(); - const auto& fallbackOutputShapes = fallbackCallback->getOutputShapes(); + auto [fallbackN, fallbackOutputShapes, fallbackTiming, fallbackExecutor] = + cpuFallbackPartial(plan, controller); // Update global outputs. - if (!fallbackExecutor->updateOutputShapes(fallbackOutputShapes, &outputShapes)) { - fallbackStatus = ErrorStatus::GENERAL_FAILURE; + if (fallbackExecutor != nullptr && + !fallbackExecutor->updateOutputShapes(fallbackOutputShapes, &outputShapes)) { + fallbackN = ANEURALNETWORKS_OP_FAILED; } // If execution was successful, continue to next step. - if (fallbackStatus == ErrorStatus::NONE) { + if (fallbackN == ANEURALNETWORKS_NO_ERROR) { // We only support collection of timing information in the case of a // single step, so it's safe to just keep track of the last step's // timing information. - timing = fallbackCallback->getTiming(); + timing = fallbackTiming; continue; } // OUTPUT_INSUFFICIENT_SIZE is not recoverable, so end execution. - if (fallbackStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) { + if (fallbackN == ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE) { + const ErrorStatus fallbackStatus = convertResultCodeToErrorStatus(fallbackN); executionCallback->notify(fallbackStatus, outputShapes, kNoTiming); return; } // Do not fallback twice if the ExecutionPlan is simple. - if (plan->isSimple()) { + if (plan.isSimple()) { + const ErrorStatus fallbackStatus = convertResultCodeToErrorStatus(fallbackN); executionCallback->notify(fallbackStatus, {}, kNoTiming); return; } @@ -498,14 +458,9 @@ static void asyncStartComputePartitioned(ExecutionBuilder* executionBuilder, // If the code has reached this point, a potentially recoverable error // occurred during the step executions. Instead, do a full execution // fallback on the CPU. - sp<ExecutionCallback> fallbackCallback; - int n = cpuFallbackFull(executionBuilder, &fallbackCallback); - if (n != ANEURALNETWORKS_NO_ERROR) { - executionCallback->notify(convertResultCodeToErrorStatus(n), {}, kNoTiming); - return; - } - executionCallback->notify(fallbackCallback->getStatus(), fallbackCallback->getOutputShapes(), - fallbackCallback->getTiming()); + auto [fullN, fullOutputShapes, fullTiming] = cpuFallbackFull(executionBuilder); + const ErrorStatus fullStatus = convertResultCodeToErrorStatus(fullN); + executionCallback->notify(fullStatus, fullOutputShapes, fullTiming); } int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback, @@ -558,7 +513,7 @@ int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback, VLOG(EXECUTION) << "ExecutionBuilder::compute (synchronous API)"; sp<ExecutionCallback> localSynchronizationCallback = new ExecutionCallback(); localSynchronizationCallback->setOnFinish(wrappedFinish); - asyncStartComputePartitioned(this, mPlan, controller, allowFallback, + asyncStartComputePartitioned(this, *mPlan, controller, allowFallback, localSynchronizationCallback); localSynchronizationCallback->wait(); if (mMeasureTiming) { @@ -579,24 +534,28 @@ int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback, executionCallback->setOnFinish(wrappedFinish); if (DeviceManager::get()->syncExecRuntime()) { VLOG(EXECUTION) << "ExecutionBuilder::compute (asynchronous API, non-threaded)"; - asyncStartComputePartitioned(this, mPlan, controller, allowFallback, executionCallback); + asyncStartComputePartitioned(this, *mPlan, controller, allowFallback, + executionCallback); } else { VLOG(EXECUTION) << "ExecutionBuilder::compute (asynchronous API)"; - std::thread thread(asyncStartComputePartitioned, this, mPlan, controller, allowFallback, - executionCallback); - executionCallback->bindThread(std::move(thread)); + std::thread asyncExecution([this, controller, allowFallback, executionCallback] { + asyncStartComputePartitioned(this, *mPlan, controller, allowFallback, + executionCallback); + }); + executionCallback->bindThread(std::move(asyncExecution)); } *synchronizationCallback = executionCallback; return ANEURALNETWORKS_NO_ERROR; } } -void ExecutionBuilder::initializeOutputShapes(std::vector<OutputShape>* outputShapes) const { - outputShapes->resize(mOutputs.size()); - for (uint32_t i = 0; i < mOutputs.size(); i++) { - (*outputShapes)[i].dimensions = mOutputs[i].dimensions; - (*outputShapes)[i].isSufficient = true; - } +std::vector<OutputShape> ExecutionBuilder::getInitialOutputShapes() const { + std::vector<OutputShape> outputShapes(mOutputs.size()); + std::transform(mOutputs.begin(), mOutputs.end(), outputShapes.begin(), + [](const auto& x) -> OutputShape { + return {.dimensions = x.dimensions, .isSufficient = true}; + }); + return outputShapes; } // Check if the dimensions "to" is updatable by dimensions "from", where "from" must @@ -741,11 +700,9 @@ bool StepExecutor::isCpu() const { return mDevice == DeviceManager::getCpuDevice(); } -int StepExecutor::startCompute(sp<ExecutionCallback>* synchronizationCallback, - const std::shared_ptr<ExecutionBurstController>& burstController) { +std::tuple<int, std::vector<OutputShape>, Timing> StepExecutor::compute( + const std::shared_ptr<ExecutionBurstController>& burstController) { CHECK(mPreparedModel != nullptr); - CHECK(synchronizationCallback != nullptr); - *synchronizationCallback = nullptr; if (VLOG_IS_ON(EXECUTION)) { logArguments("input", mInputs); @@ -757,19 +714,12 @@ int StepExecutor::startCompute(sp<ExecutionCallback>* synchronizationCallback, mPreparedModel->execute(mInputs, mOutputs, mMemories, burstController, measure); mExecutionBuilder->reportTiming(timing); - if (n != ANEURALNETWORKS_NO_ERROR && n != ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE) { - return n; - } - - const ErrorStatus status = convertResultCodeToErrorStatus(n); - *synchronizationCallback = new ExecutionCallback(); - (*synchronizationCallback)->notify_1_2(status, outputShapes, timing); - return ANEURALNETWORKS_NO_ERROR; + return {n, std::move(outputShapes), timing}; } // For cpuFallback{Partial,Full}, recompile the model on CPU and then start compute. -int StepExecutor::startComputeOnCpuFallback(sp<ExecutionCallback>* synchronizationCallback) { - NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "StepExecutor::startComputeOnCpuFallback"); +std::tuple<int, std::vector<OutputShape>, Timing> StepExecutor::computeOnCpuFallback() { + NNTRACE_RT(NNTRACE_PHASE_EXECUTION, "StepExecutor::computeOnCpuFallback"); VLOG(EXECUTION) << "Re-compile the model on CPU"; mDevice = DeviceManager::getCpuDevice(); mPreparedModel = nullptr; @@ -780,8 +730,10 @@ int StepExecutor::startComputeOnCpuFallback(sp<ExecutionCallback>* synchronizati static_cast<ExecutionPreference>(ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER); const auto [n, preparedModel] = mDevice->prepareModel(makeModel, preference, {}, {}); mPreparedModel = preparedModel; - NN_RETURN_IF_ERROR(n); - return startCompute(synchronizationCallback, /*burstController=*/nullptr); + if (n != ANEURALNETWORKS_NO_ERROR) { + return {n, {}, kNoTiming}; + } + return compute(/*burstController=*/nullptr); } } // namespace nn diff --git a/nn/runtime/ExecutionBuilder.h b/nn/runtime/ExecutionBuilder.h index f07335a01..3d8ab3e6c 100644 --- a/nn/runtime/ExecutionBuilder.h +++ b/nn/runtime/ExecutionBuilder.h @@ -69,7 +69,7 @@ class ExecutionBuilder { int burstCompute(BurstBuilder* burst) { return compute(nullptr, burst); } // Initialize output dimensional information from ModelArgumentInfo. - void initializeOutputShapes(std::vector<hal::OutputShape>* outputShapes) const; + std::vector<hal::OutputShape> getInitialOutputShapes() const; int getOutputOperandDimensions(uint32_t index, uint32_t* dimensions); int getOutputOperandRank(uint32_t index, uint32_t* rank); @@ -160,7 +160,7 @@ class StepExecutor { // is executing the entire model from the ExecutionBuilder). void mapInputsAndOutputsTrivially(); - // Update output shapes returned from ExecutionCallback to ExecutionBuilder. + // Update output shapes with shapes returned from execution. bool updateOutputShapes(const std::vector<hal::OutputShape>& from, std::vector<hal::OutputShape>* to); @@ -189,12 +189,12 @@ class StepExecutor { } // Executes using the (driver, preparedModel) specified at construction time. - int startCompute(sp<ExecutionCallback>* synchronizationCallback, - const std::shared_ptr<ExecutionBurstController>& burstController = nullptr); + std::tuple<int, std::vector<hal::OutputShape>, hal::Timing> compute( + const std::shared_ptr<ExecutionBurstController>& burstController = nullptr); // Re-compiles and executes using the CPU, regardless of the (driver, // preparedModel) specified at construction time. - int startComputeOnCpuFallback(sp<ExecutionCallback>* synchronizationCallback); + std::tuple<int, std::vector<hal::OutputShape>, hal::Timing> computeOnCpuFallback(); bool isCpu() const; diff --git a/nn/runtime/ExecutionPlan.cpp b/nn/runtime/ExecutionPlan.cpp index 901305216..ed38aec2f 100644 --- a/nn/runtime/ExecutionPlan.cpp +++ b/nn/runtime/ExecutionPlan.cpp @@ -71,8 +71,9 @@ int compile(const Device& device, const ModelBuilder& model, int executionPrefer *preparedModel = nullptr; std::optional<CacheToken> cacheToken; - if (device.isCachingSupported() && token->ok() && token->updateFromString(device.getName()) && - token->updateFromString(device.getVersionString()) && + if (device.isCachingSupported() && token->ok() && + token->updateFromString(device.getName().c_str()) && + token->updateFromString(device.getVersionString().c_str()) && token->update(&executionPreference, sizeof(executionPreference)) && token->finish()) { cacheToken.emplace(token->getCacheToken()); } @@ -996,13 +997,13 @@ class CanDo { CanDo() {} void initialize(const MetaModel& metaModel, std::shared_ptr<Device> device) { - device->getSupportedOperations(metaModel, &mSupportsOperationByIndex); + mSupportsOperationByIndex = device->getSupportedOperations(metaModel); } bool check(size_t operationIndex) const { return mSupportsOperationByIndex[operationIndex]; } private: - hidl_vec<bool> mSupportsOperationByIndex; + std::vector<bool> mSupportsOperationByIndex; }; } // anonymous namespace diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp index 34378b3fd..1be1bf68b 100644 --- a/nn/runtime/Manager.cpp +++ b/nn/runtime/Manager.cpp @@ -52,28 +52,38 @@ const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX // A Device with actual underlying driver class DriverDevice : public Device { public: - DriverDevice(std::string name, const sp<V1_0::IDevice>& device); - - // Returns true if succesfully initialized. - bool initialize(); - - const char* getName() const override { return mName.c_str(); } - const char* getVersionString() const override { return mVersionString.c_str(); } - int64_t getFeatureLevel() const override { return mInterface->getFeatureLevel(); } - int32_t getType() const override { return mInterface->getType(); } - hidl_vec<Extension> getSupportedExtensions() const override; - void getSupportedOperations(const MetaModel& metaModel, - hidl_vec<bool>* supportedOperations) const override; - PerformanceInfo getPerformance(OperandType type) const override; + // Create a DriverDevice from a name and an IDevice. + // Returns nullptr on failure. + static std::shared_ptr<DriverDevice> create(std::string name, sp<V1_0::IDevice> device); + + // Prefer using DriverDevice::create + DriverDevice(std::shared_ptr<VersionedIDevice> device); + + const std::string& getName() const override { return kInterface->getName(); } + const std::string& getVersionString() const override { return kInterface->getVersionString(); } + int64_t getFeatureLevel() const override { return kInterface->getFeatureLevel(); } + int32_t getType() const override { return kInterface->getType(); } + const std::vector<Extension>& getSupportedExtensions() const override { + return kInterface->getSupportedExtensions(); + } + std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const override; + PerformanceInfo getPerformance(OperandType type) const override { + const auto& capabilities = kInterface->getCapabilities(); + return lookup(capabilities.operandPerformance, type); + } PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const override { - return mCapabilities.relaxedFloat32toFloat16PerformanceScalar; + const auto& capabilities = kInterface->getCapabilities(); + return capabilities.relaxedFloat32toFloat16PerformanceScalar; } PerformanceInfo getRelaxedFloat32toFloat16PerformanceTensor() const override { - return mCapabilities.relaxedFloat32toFloat16PerformanceTensor; + const auto& capabilities = kInterface->getCapabilities(); + return capabilities.relaxedFloat32toFloat16PerformanceTensor; } bool isCachingSupported() const override { // Caching is supported if either of numModelCache or numDataCache is greater than 0. - return mNumCacheFiles.first > 0 || mNumCacheFiles.second > 0; + const auto [numModelCacheFiles, numDataCacheFiles] = + kInterface->getNumberOfCacheFilesNeeded(); + return numModelCacheFiles > 0 || numDataCacheFiles > 0; } std::pair<int, std::shared_ptr<PreparedModel>> prepareModel( @@ -88,12 +98,7 @@ class DriverDevice : public Device { std::pair<int, std::shared_ptr<PreparedModel>> prepareModelFromCacheInternal( const std::string& cacheDir, const CacheToken& token) const; - std::string mName; - std::string mVersionString; - const std::shared_ptr<VersionedIDevice> mInterface; - Capabilities mCapabilities; - hidl_vec<Extension> mSupportedExtensions; - std::pair<uint32_t, uint32_t> mNumCacheFiles; + const std::shared_ptr<VersionedIDevice> kInterface; #ifdef NN_DEBUGGABLE // For debugging: behavior of IDevice::getSupportedOperations for SampleDriver. @@ -124,104 +129,57 @@ class DriverPreparedModel : public PreparedModel { const std::shared_ptr<VersionedIPreparedModel> mPreparedModel; }; -DriverDevice::DriverDevice(std::string name, const sp<V1_0::IDevice>& device) - : mName(std::move(name)), mInterface(VersionedIDevice::create(mName, device)) {} - -// TODO: handle errors from initialize correctly -bool DriverDevice::initialize() { +DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device) + : kInterface(std::move(device)) { #ifdef NN_DEBUGGABLE static const char samplePrefix[] = "sample"; - - mSupported = (mName.substr(0, sizeof(samplePrefix) - 1) == samplePrefix) - ? getProp("debug.nn.sample.supported") - : 0; -#endif // NN_DEBUGGABLE - - ErrorStatus status = ErrorStatus::GENERAL_FAILURE; - - if (mInterface == nullptr) { - LOG(ERROR) << "DriverDevice contains invalid interface object."; - return false; - } - - std::tie(status, mCapabilities) = mInterface->getCapabilities(); - if (status != ErrorStatus::NONE) { - LOG(ERROR) << "IDevice::getCapabilities returned the error " << toString(status); - return false; - } - VLOG(MANAGER) << "Capab " << toString(mCapabilities); - - std::tie(status, mVersionString) = mInterface->getVersionString(); - // TODO(miaowang): add a validation test case for in case of error. - if (status != ErrorStatus::NONE) { - LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(status); - return false; - } - - std::tie(status, mSupportedExtensions) = mInterface->getSupportedExtensions(); - if (status != ErrorStatus::NONE) { - LOG(ERROR) << "IDevice::getSupportedExtensions returned the error " << toString(status); - return false; + if (getName().substr(0, sizeof(samplePrefix) - 1) == samplePrefix) { + mSupported = getProp("debug.nn.sample.supported"); } +#endif // NN_DEBUGGABLE +} - std::tie(status, mNumCacheFiles.first, mNumCacheFiles.second) = - mInterface->getNumberOfCacheFilesNeeded(); - if (status != ErrorStatus::NONE) { - LOG(WARNING) << "IDevice::getNumberOfCacheFilesNeeded returned the error " - << toString(status); - mNumCacheFiles = {0, 0}; - } - if (mNumCacheFiles.first > static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) || - mNumCacheFiles.second > static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES)) { - LOG(WARNING) - << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files " - "numModelCache = " - << mNumCacheFiles.first << ", numDataCache = " << mNumCacheFiles.second; - mNumCacheFiles = {0, 0}; +std::shared_ptr<DriverDevice> DriverDevice::create(std::string name, sp<V1_0::IDevice> device) { + CHECK(device != nullptr); + std::shared_ptr<VersionedIDevice> versionedDevice = + VersionedIDevice::create(name, std::move(device)); + if (versionedDevice == nullptr) { + LOG(ERROR) << "DriverDevice::create failed to create VersionedIDevice object for service " + << name; + return nullptr; } - return true; -} -hidl_vec<Extension> DriverDevice::getSupportedExtensions() const { - return mSupportedExtensions; + return std::make_shared<DriverDevice>(std::move(versionedDevice)); } -void DriverDevice::getSupportedOperations(const MetaModel& metaModel, - hidl_vec<bool>* outSupportedOperations) const { +std::vector<bool> DriverDevice::getSupportedOperations(const MetaModel& metaModel) const { // Query the driver for what it can do. ErrorStatus status = ErrorStatus::GENERAL_FAILURE; - hidl_vec<bool> supportedOperations; - std::tie(status, supportedOperations) = mInterface->getSupportedOperations(metaModel); + std::vector<bool> supportedOperations; + std::tie(status, supportedOperations) = kInterface->getSupportedOperations(metaModel); const Model& hidlModel = metaModel.getModel(); if (status != ErrorStatus::NONE) { LOG(ERROR) << "IDevice::getSupportedOperations returned the error " << toString(status); // Set the supported operation vectors to all false, so we won't use this driver. - outSupportedOperations->resize(hidlModel.operations.size()); - std::fill(outSupportedOperations->begin(), outSupportedOperations->end(), false); - return; + return std::vector<bool>(hidlModel.operations.size(), false); } if (supportedOperations.size() != hidlModel.operations.size()) { LOG(ERROR) << "IDevice::getSupportedOperations returned a vector of length " << supportedOperations.size() << " when expecting " << hidlModel.operations.size(); // Set the supported operation vectors to all false, so we won't use this driver. - outSupportedOperations->resize(hidlModel.operations.size()); - std::fill(outSupportedOperations->begin(), outSupportedOperations->end(), false); - return; + return std::vector<bool>(hidlModel.operations.size(), false); } - *outSupportedOperations = std::move(supportedOperations); - #ifdef NN_DEBUGGABLE if (mSupported != 1) { - return; + return supportedOperations; } - const uint32_t baseAccumulator = std::hash<std::string>{}(mName); - for (size_t operationIndex = 0; operationIndex < outSupportedOperations->size(); - operationIndex++) { - if (!(*outSupportedOperations)[operationIndex]) { + const uint32_t baseAccumulator = std::hash<std::string>{}(getName()); + for (size_t operationIndex = 0; operationIndex < supportedOperations.size(); operationIndex++) { + if (!supportedOperations[operationIndex]) { continue; } @@ -245,14 +203,12 @@ void DriverDevice::getSupportedOperations(const MetaModel& metaModel, accumulateOperands(operation.inputs); accumulateOperands(operation.outputs); if (accumulator & 1) { - (*outSupportedOperations)[operationIndex] = false; + supportedOperations[operationIndex] = false; } } #endif // NN_DEBUGGABLE -} -PerformanceInfo DriverDevice::getPerformance(OperandType type) const { - return lookup(mCapabilities.operandPerformance, type); + return supportedOperations; } // Opens cache file by filename and sets the handle to the opened fd. Returns false on fail. The @@ -322,7 +278,7 @@ static bool getCacheHandles(const std::string& cacheDir, const CacheToken& token static std::pair<int, std::shared_ptr<PreparedModel>> prepareModelCheck( ErrorStatus status, const std::shared_ptr<VersionedIPreparedModel>& preparedModel, - const char* prepareName, const char* serviceName) { + const char* prepareName, const std::string& serviceName) { if (status != ErrorStatus::NONE) { LOG(ERROR) << prepareName << " on " << serviceName << " failed: " << "prepareReturnStatus=" << toString(status); @@ -344,7 +300,7 @@ std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelInterna hidl_vec<hidl_handle> modelCache, dataCache; if (!maybeToken.has_value() || - !getCacheHandles(cacheDir, *maybeToken, mNumCacheFiles, + !getCacheHandles(cacheDir, *maybeToken, kInterface->getNumberOfCacheFilesNeeded(), /*createIfNotExist=*/true, &modelCache, &dataCache)) { modelCache.resize(0); dataCache.resize(0); @@ -353,7 +309,7 @@ std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelInterna static const CacheToken kNullToken{}; const CacheToken token = maybeToken.value_or(kNullToken); const auto [status, preparedModel] = - mInterface->prepareModel(model, preference, modelCache, dataCache, token); + kInterface->prepareModel(model, preference, modelCache, dataCache, token); return prepareModelCheck(status, preparedModel, "prepareModel", getName()); } @@ -365,13 +321,13 @@ std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelFromCac VLOG(COMPILATION) << "prepareModelFromCache"; hidl_vec<hidl_handle> modelCache, dataCache; - if (!getCacheHandles(cacheDir, token, mNumCacheFiles, + if (!getCacheHandles(cacheDir, token, kInterface->getNumberOfCacheFilesNeeded(), /*createIfNotExist=*/false, &modelCache, &dataCache)) { return {ANEURALNETWORKS_OP_FAILED, nullptr}; } const auto [status, preparedModel] = - mInterface->prepareModelFromCache(modelCache, dataCache, token); + kInterface->prepareModelFromCache(modelCache, dataCache, token); return prepareModelCheck(status, preparedModel, "prepareModelFromCache", getName()); } @@ -571,13 +527,14 @@ class CpuDevice : public Device { return instance; } - const char* getName() const override { return kName.c_str(); } - const char* getVersionString() const override { return kVersionString.c_str(); } + const std::string& getName() const override { return kName; } + const std::string& getVersionString() const override { return kVersionString; } int64_t getFeatureLevel() const override { return kFeatureLevel; } int32_t getType() const override { return ANEURALNETWORKS_DEVICE_CPU; } - hidl_vec<Extension> getSupportedExtensions() const override { return {/* No extensions. */}; } - void getSupportedOperations(const MetaModel& metaModel, - hidl_vec<bool>* supportedOperations) const override; + const std::vector<Extension>& getSupportedExtensions() const override { + return kSupportedExtensions; + } + std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const override; PerformanceInfo getPerformance(OperandType) const override { return kPerformance; } PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const override { return kPerformance; @@ -600,6 +557,7 @@ class CpuDevice : public Device { // Since the performance is a ratio compared to the CPU performance, // by definition the performance of the CPU is 1.0. const PerformanceInfo kPerformance = {.execTime = 1.0f, .powerUsage = 1.0f}; + const std::vector<Extension> kSupportedExtensions{/* No extensions. */}; }; // A special abstracted PreparedModel for the CPU, constructed by CpuDevice. @@ -629,11 +587,10 @@ class CpuPreparedModel : public PreparedModel { const std::vector<RunTimePoolInfo> mModelPoolInfos; }; -void CpuDevice::getSupportedOperations(const MetaModel& metaModel, - hidl_vec<bool>* supportedOperations) const { +std::vector<bool> CpuDevice::getSupportedOperations(const MetaModel& metaModel) const { const Model& hidlModel = metaModel.getModel(); const size_t count = hidlModel.operations.size(); - hidl_vec<bool> result(count); + std::vector<bool> result(count, false); for (size_t i = 0; i < count; i++) { // TODO(b/119870033): Decide whether and how post-P operations would be supported on CPU. // We may want to use the slicer for CpuDevice just as we do for @@ -642,7 +599,7 @@ void CpuDevice::getSupportedOperations(const MetaModel& metaModel, result[i] = !isExtensionOperationType(operationType) && operationType != OperationType::OEM_OPERATION; } - *supportedOperations = std::move(result); + return result; } std::pair<int, std::shared_ptr<PreparedModel>> CpuDevice::prepareModel( @@ -751,42 +708,33 @@ std::shared_ptr<Device> DeviceManager::getCpuDevice() { std::shared_ptr<Device> DeviceManager::forTest_makeDriverDevice(const std::string& name, const sp<V1_0::IDevice>& device) { - auto driverDevice = std::make_shared<DriverDevice>(name, device); - CHECK(driverDevice->initialize()); + const auto driverDevice = DriverDevice::create(name, device); + CHECK(driverDevice != nullptr); return driverDevice; } void DeviceManager::findAvailableDevices() { - using ::android::hidl::manager::V1_2::IServiceManager; VLOG(MANAGER) << "findAvailableDevices"; - sp<IServiceManager> manager = hardware::defaultServiceManager1_2(); - if (manager == nullptr) { - LOG(ERROR) << "Unable to open defaultServiceManager"; - return; + // register driver devices + const auto names = hardware::getAllHalInstanceNames(V1_0::IDevice::descriptor); + for (const auto& name : names) { + VLOG(MANAGER) << "Found interface " << name; + sp<V1_0::IDevice> device = V1_0::IDevice::getService(name); + if (device == nullptr) { + LOG(ERROR) << "Got a null IDEVICE for " << name; + continue; + } + registerDevice(name, device); } - manager->listManifestByInterface( - V1_0::IDevice::descriptor, [this](const hidl_vec<hidl_string>& names) { - for (const auto& name : names) { - VLOG(MANAGER) << "Found interface " << name.c_str(); - sp<V1_0::IDevice> device = V1_0::IDevice::getService(name); - if (device == nullptr) { - LOG(ERROR) << "Got a null IDEVICE for " << name.c_str(); - continue; - } - registerDevice(name.c_str(), device); - } - }); - // register CPU fallback device mDevices.push_back(CpuDevice::get()); mDevicesCpuOnly.push_back(CpuDevice::get()); } -void DeviceManager::registerDevice(const char* name, const sp<V1_0::IDevice>& device) { - auto d = std::make_shared<DriverDevice>(name, device); - if (d->initialize()) { +void DeviceManager::registerDevice(const std::string& name, const sp<V1_0::IDevice>& device) { + if (const auto d = DriverDevice::create(name, device)) { mDevices.push_back(d); } } diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h index 82042d108..cece348ee 100644 --- a/nn/runtime/Manager.h +++ b/nn/runtime/Manager.h @@ -68,15 +68,14 @@ class Device { virtual ~Device() = default; // Introspection methods returning device information - virtual const char* getName() const = 0; - virtual const char* getVersionString() const = 0; + virtual const std::string& getName() const = 0; + virtual const std::string& getVersionString() const = 0; virtual int64_t getFeatureLevel() const = 0; virtual int32_t getType() const = 0; - virtual hal::hidl_vec<hal::Extension> getSupportedExtensions() const = 0; + virtual const std::vector<hal::Extension>& getSupportedExtensions() const = 0; // See the MetaModel class in MetaModel.h for more details. - virtual void getSupportedOperations(const MetaModel& metaModel, - hal::hidl_vec<bool>* supportedOperations) const = 0; + virtual std::vector<bool> getSupportedOperations(const MetaModel& metaModel) const = 0; virtual hal::PerformanceInfo getPerformance(hal::OperandType type) const = 0; virtual hal::PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const = 0; @@ -142,7 +141,7 @@ class DeviceManager { } // Register a test device. - void forTest_registerDevice(const char* name, const sp<hal::V1_0::IDevice>& device) { + void forTest_registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device) { registerDevice(name, device); } @@ -166,7 +165,7 @@ class DeviceManager { DeviceManager(); // Adds a device for the manager to use. - void registerDevice(const char* name, const sp<hal::V1_0::IDevice>& device); + void registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device); void findAvailableDevices(); diff --git a/nn/runtime/NeuralNetworks.cpp b/nn/runtime/NeuralNetworks.cpp index 17667fd50..3542c24b1 100644 --- a/nn/runtime/NeuralNetworks.cpp +++ b/nn/runtime/NeuralNetworks.cpp @@ -582,7 +582,7 @@ int ANeuralNetworksDevice_getName(const ANeuralNetworksDevice* device, const cha return ANEURALNETWORKS_UNEXPECTED_NULL; } const Device* d = reinterpret_cast<const Device*>(device); - *name = d->getName(); + *name = d->getName().c_str(); return ANEURALNETWORKS_NO_ERROR; } @@ -592,7 +592,7 @@ int ANeuralNetworksDevice_getVersion(const ANeuralNetworksDevice* device, const return ANEURALNETWORKS_UNEXPECTED_NULL; } const Device* d = reinterpret_cast<const Device*>(device); - *version = d->getVersionString(); + *version = d->getVersionString().c_str(); return ANEURALNETWORKS_NO_ERROR; } @@ -665,8 +665,7 @@ int ANeuralNetworksModel_getSupportedOperationsForDevices( Device* d = reinterpret_cast<Device*>(const_cast<ANeuralNetworksDevice*>(devices[i])); const MetaModel metaModel(hidlModel, DeviceManager::get()->strictSlicing()); - hidl_vec<bool> supportsByDevice; - d->getSupportedOperations(metaModel, &supportsByDevice); + const std::vector<bool> supportsByDevice = d->getSupportedOperations(metaModel); for (uint32_t j = 0; j < supportsByDevice.size(); j++) { uint32_t originalIdx = opMap[j]; supportedOps[originalIdx] |= supportsByDevice[j]; @@ -1186,16 +1185,12 @@ int ANeuralNetworksDevice_getExtensionSupport(const ANeuralNetworksDevice* devic return ANEURALNETWORKS_UNEXPECTED_NULL; } - Device* d = reinterpret_cast<Device*>(const_cast<ANeuralNetworksDevice*>(device)); - hidl_vec<Extension> supportedExtensions = d->getSupportedExtensions(); - - *isExtensionSupported = false; - for (const Extension& supportedExtension : supportedExtensions) { - if (supportedExtension.name == extensionName) { - *isExtensionSupported = true; - break; - } - } + const Device* d = reinterpret_cast<const Device*>(device); + const auto& supportedExtensions = d->getSupportedExtensions(); + *isExtensionSupported = std::any_of(supportedExtensions.begin(), supportedExtensions.end(), + [extensionName](const auto& supportedExtension) { + return supportedExtension.name == extensionName; + }); return ANEURALNETWORKS_NO_ERROR; } diff --git a/nn/runtime/TypeManager.cpp b/nn/runtime/TypeManager.cpp index 854e28ab9..40b0e3d44 100644 --- a/nn/runtime/TypeManager.cpp +++ b/nn/runtime/TypeManager.cpp @@ -181,7 +181,7 @@ bool TypeManager::isExtensionsUseAllowed(const AppPackageInfo& appPackageInfo, void TypeManager::findAvailableExtensions() { for (const std::shared_ptr<Device>& device : mDeviceManager->getDrivers()) { - for (const Extension extension : device->getSupportedExtensions()) { + for (const Extension& extension : device->getSupportedExtensions()) { registerExtension(extension, device->getName()); } } diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index ba6e2af7c..8da4e7c1f 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -306,21 +306,88 @@ std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExec std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName, sp<V1_0::IDevice> device) { + CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object."; + auto core = Core::create(std::move(device)); if (!core.has_value()) { LOG(ERROR) << "VersionedIDevice::create failed to create Core."; return nullptr; } - // return a valid VersionedIDevice object - return std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value())); + // create and initialize a VersionedIDevice object + const auto versionedIDevice = + std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value())); + if (!versionedIDevice->initializeInternal()) { + LOG(ERROR) << "VersionedIDevice failed to initialize"; + return nullptr; + } + + // return a valid, initialized VersionedIDevice object + return versionedIDevice; } VersionedIDevice::VersionedIDevice(std::string serviceName, Core core) : mServiceName(std::move(serviceName)), mCore(std::move(core)) {} +bool VersionedIDevice::initializeInternal() { + auto [capabilitiesStatus, capabilities] = getCapabilitiesInternal(); + if (capabilitiesStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getCapabilities* returned the error " + << toString(capabilitiesStatus); + return false; + } + VLOG(MANAGER) << "Capab " << toString(capabilities); + + const auto [versionStatus, versionString] = getVersionStringInternal(); + // TODO(miaowang): add a validation test case for in case of error. + if (versionStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(versionStatus); + return false; + } + + const int32_t type = getTypeInternal(); + if (type == -1) { + LOG(ERROR) << "IDevice::getType returned an error"; + return false; + } + + const auto [extensionsStatus, supportedExtensions] = getSupportedExtensionsInternal(); + if (extensionsStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getSupportedExtensions returned the error " + << toString(extensionsStatus); + return false; + } + + const auto [cacheFilesStatus, numModelCacheFiles, numDataCacheFiles] = + getNumberOfCacheFilesNeededInternal(); + if (cacheFilesStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getNumberOfCacheFilesNeeded returned the error " + << toString(cacheFilesStatus); + return false; + } + + // The following limit is enforced by VTS + constexpr uint32_t maxNumCacheFiles = + static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES); + if (numModelCacheFiles > maxNumCacheFiles || numDataCacheFiles > maxNumCacheFiles) { + LOG(ERROR) + << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files: " + "numModelCacheFiles = " + << numModelCacheFiles << ", numDataCacheFiles = " << numDataCacheFiles + << ", maxNumCacheFiles = " << maxNumCacheFiles; + return false; + } + + // set internal members + mCapabilities = std::move(capabilities); + mSupportedExtensions = supportedExtensions; + mType = type; + mVersionString = versionString; + mNumberOfCacheFilesNeeded = {numModelCacheFiles, numDataCacheFiles}; + return true; +} + std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) { - // verify input CHECK(device != nullptr) << "VersionedIDevice::Core::create passed invalid device object."; // create death handler object @@ -480,7 +547,7 @@ Return<T_Return> VersionedIDevice::recoverable( return ret; } -std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() const { +std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilitiesInternal() const { const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}}; std::pair<ErrorStatus, Capabilities> result; @@ -559,7 +626,12 @@ std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() const { return {ErrorStatus::DEVICE_UNAVAILABLE, {}}; } -std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() const { +const Capabilities& VersionedIDevice::getCapabilities() const { + return mCapabilities; +} + +std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensionsInternal() + const { const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}}; // version 1.2+ HAL @@ -590,6 +662,10 @@ std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtens return {ErrorStatus::DEVICE_UNAVAILABLE, {}}; } +const std::vector<Extension>& VersionedIDevice::getSupportedExtensions() const { + return mSupportedExtensions; +} + std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations( const MetaModel& metaModel) const { const std::pair<ErrorStatus, hidl_vec<bool>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}}; @@ -967,7 +1043,7 @@ int64_t VersionedIDevice::getFeatureLevel() const { } } -int32_t VersionedIDevice::getType() const { +int32_t VersionedIDevice::getTypeInternal() const { constexpr int32_t kFailure = -1; // version 1.2+ HAL @@ -988,12 +1064,22 @@ int32_t VersionedIDevice::getType() const { return result; } - // version too low or no device available - LOG(INFO) << "Unknown NNAPI device type."; - return ANEURALNETWORKS_DEVICE_UNKNOWN; + // version too low + if (getDevice<V1_0::IDevice>() != nullptr) { + LOG(INFO) << "Unknown NNAPI device type."; + return ANEURALNETWORKS_DEVICE_UNKNOWN; + } + + // No device available + LOG(ERROR) << "Could not handle getType"; + return kFailure; +} + +int32_t VersionedIDevice::getType() const { + return mType; } -std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() const { +std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionStringInternal() const { const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""}; // version 1.2+ HAL @@ -1023,7 +1109,12 @@ std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() const { return kFailure; } -std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const { +const std::string& VersionedIDevice::getVersionString() const { + return mVersionString; +} + +std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeededInternal() + const { constexpr std::tuple<ErrorStatus, uint32_t, uint32_t> kFailure = {ErrorStatus::GENERAL_FAILURE, 0, 0}; @@ -1055,5 +1146,13 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi return kFailure; } +std::pair<uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const { + return mNumberOfCacheFilesNeeded; +} + +const std::string& VersionedIDevice::getName() const { + return mServiceName; +} + } // namespace nn } // namespace android diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h index 87e776507..7f6f11af3 100644 --- a/nn/runtime/VersionedInterfaces.h +++ b/nn/runtime/VersionedInterfaces.h @@ -94,13 +94,9 @@ class VersionedIDevice { /** * Gets the capabilities of a driver. * - * @return status Error status of the call, must be: - * - NONE if successful - * - DEVICE_UNAVAILABLE if driver is offline or busy - * - GENERAL_FAILURE if there is an unspecified error * @return capabilities Capabilities of the driver. */ - std::pair<hal::ErrorStatus, hal::Capabilities> getCapabilities() const; + const hal::Capabilities& getCapabilities() const; /** * Gets information about extensions supported by the driver implementation. @@ -111,13 +107,9 @@ class VersionedIDevice { * All extension operations and operands must be fully supported for the * extension to appear in the list of supported extensions. * - * @return status Error status of the call, must be: - * - NONE if successful - * - DEVICE_UNAVAILABLE if driver is offline or busy - * - GENERAL_FAILURE if there is an unspecified error * @return extensions A list of supported extensions. */ - std::pair<hal::ErrorStatus, hal::hidl_vec<hal::Extension>> getSupportedExtensions() const; + const std::vector<hal::Extension>& getSupportedExtensions() const; /** * Gets the supported operations in a MetaModel. @@ -340,14 +332,12 @@ class VersionedIDevice { /** * Returns the device type of a driver. * - * @return deviceType The type of a given device, which can help application developers - * developers to distribute Machine Learning workloads and other workloads - * such as graphical rendering. E.g., for an app which renders AR scenes - * based on real time object detection results, the developer could choose - * an ACCELERATOR type device for ML workloads, and reserve GPU for - * graphical rendering. - * Return -1 if the driver is offline or busy, or the query resulted in - * an unspecified error. + * @return deviceType The type of a given device, which can help application + * developers to distribute Machine Learning workloads and other + * workloads such as graphical rendering. E.g., for an app which renders + * AR scenes based on real time object detection results, the developer + * could choose an ACCELERATOR type device for ML workloads, and reserve + * GPU for graphical rendering. */ int32_t getType() const; @@ -371,15 +361,9 @@ class VersionedIDevice { * the driver cannot meet that requirement because of bugs or certain optimizations. * The application can filter out versions of these drivers. * - * @return status Error status returned from querying the version string. Must be: - * - NONE if the query was successful - * - DEVICE_UNAVAILABLE if driver is offline or busy - * - GENERAL_FAILURE if the query resulted in an - * unspecified error * @return version The version string of the device implementation. - * Must have nonzero length if the query is successful, and must be an empty string if not. */ - std::pair<hal::ErrorStatus, hal::hidl_string> getVersionString() const; + const std::string& getVersionString() const; /** * Gets the caching requirements of the driver implementation. @@ -408,10 +392,6 @@ class VersionedIDevice { * IDevice::prepareModelFromCache or providing cache file descriptors to * IDevice::prepareModel_1_2. * - * @return status Error status of the call, must be: - * - NONE if successful - * - DEVICE_UNAVAILABLE if driver is offline or busy - * - GENERAL_FAILURE if there is an unspecified error * @return numModelCache An unsigned integer indicating how many files for model cache * the driver needs to cache a single prepared model. It must * be less than or equal to Constant::MAX_NUMBER_OF_CACHE_FILES. @@ -419,9 +399,35 @@ class VersionedIDevice { * the driver needs to cache a single prepared model. It must * be less than or equal to Constant::MAX_NUMBER_OF_CACHE_FILES. */ - std::tuple<hal::ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const; + std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const; + + /** + * Returns the name of the service. + * + * @return Name of the service. + */ + const std::string& getName() const; private: + // initializeInternal is called once during VersionedIDevice creation. + // 'true' indicates successful initialization. + bool initializeInternal(); + + // internal helper methods + std::pair<hal::ErrorStatus, hal::Capabilities> getCapabilitiesInternal() const; + std::pair<hal::ErrorStatus, hal::hidl_vec<hal::Extension>> getSupportedExtensionsInternal() + const; + int32_t getTypeInternal() const; + std::pair<hal::ErrorStatus, hal::hidl_string> getVersionStringInternal() const; + std::tuple<hal::ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeededInternal() const; + + // internal members for the cached results of the internal methods above + hal::Capabilities mCapabilities; + std::vector<hal::Extension> mSupportedExtensions; + int32_t mType; + std::string mVersionString; + std::pair<uint32_t, uint32_t> mNumberOfCacheFilesNeeded; + /** * This is a utility class for VersionedIDevice that encapsulates a * V1_0::IDevice, any appropriate downcasts to newer interfaces, and a diff --git a/nn/runtime/test/TestCompilationCaching.cpp b/nn/runtime/test/TestCompilationCaching.cpp index c061e853d..bce2836d2 100644 --- a/nn/runtime/test/TestCompilationCaching.cpp +++ b/nn/runtime/test/TestCompilationCaching.cpp @@ -14,12 +14,14 @@ * limitations under the License. */ +#include <android-base/scopeguard.h> #include <gtest/gtest.h> #include <cstdlib> #include <filesystem> #include <numeric> #include <string> +#include <string_view> #include <tuple> #include <vector> @@ -48,13 +50,32 @@ namespace { enum class HasCalledPrepareModel { NO, WITHOUT_CACHING, WITH_CACHING }; +// Print HasCalledPrepareModel enum for better GTEST failure messages +std::ostream& operator<<(std::ostream& os, HasCalledPrepareModel hasCalledPrepareModel) { + switch (hasCalledPrepareModel) { + case HasCalledPrepareModel::NO: + return os << "NO"; + case HasCalledPrepareModel::WITHOUT_CACHING: + return os << "WITHOUT_CACHING"; + case HasCalledPrepareModel::WITH_CACHING: + return os << "WITH_CACHING"; + } + CHECK(false) << "HasCalledPrepareModel print called with invalid code " + << static_cast<int>(hasCalledPrepareModel); + return os; +} + +// Whether the driver is expected to be registered because it can pass initialization. +bool canDeviceBeRegistered(ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) { + constexpr uint32_t maxNumCacheFiles = + static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES); + return error == ErrorStatus::NONE && numModelCache <= maxNumCacheFiles && + numDataCache <= maxNumCacheFiles; +} + // Whether the driver supports caching based on the returns from getNumberOfCacheFilesNeeded. -bool isCachingSupportedAndNoError(ErrorStatus error, uint32_t numModelCache, - uint32_t numDataCache) { - return error == ErrorStatus::NONE && - numModelCache <= static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) && - numDataCache <= static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) && - (numModelCache != 0 || numDataCache != 0); +bool isCachingSupported(uint32_t numModelCache, uint32_t numDataCache) { + return numModelCache != 0 || numDataCache != 0; } // This is an IDevice for testing purposes which overrides several methods from sample driver: @@ -99,9 +120,10 @@ class CachingDriver : public sample_driver::SampleDriver { }; public: - CachingDriver(const char* name, ErrorStatus errorStatusGetNumCacheFiles, uint32_t numModelCache, - uint32_t numDataCache, ErrorStatus errorStatusPrepareFromCache) - : SampleDriver(name), + CachingDriver(std::string_view name, ErrorStatus errorStatusGetNumCacheFiles, + uint32_t numModelCache, uint32_t numDataCache, + ErrorStatus errorStatusPrepareFromCache) + : SampleDriver(name.data()), mErrorStatusGetNumCacheFiles(errorStatusGetNumCacheFiles), mNumModelCache(numModelCache), mNumDataCache(numDataCache), @@ -181,8 +203,7 @@ class CachingDriver : public sample_driver::SampleDriver { private: // Checks the number of cache files passed to the driver from runtime. void checkNumberOfCacheHandles(size_t modelCache, size_t dataCache) { - if (isCachingSupportedAndNoError(mErrorStatusGetNumCacheFiles, mNumModelCache, - mNumDataCache)) { + if (isCachingSupported(mNumModelCache, mNumDataCache)) { if (modelCache != 0 || dataCache != 0) { ASSERT_EQ(modelCache, mNumModelCache); ASSERT_EQ(dataCache, mNumDataCache); @@ -239,12 +260,69 @@ void CreateBroadcastAddModel(test_wrapper::Model* model) { ASSERT_EQ(model->finish(), Result::NO_ERROR); } -// Test model compilation with a driver parameterized with +void getDeviceWithName(std::string_view deviceName, const ANeuralNetworksDevice** outputDevice) { + uint32_t numDevices = 0; + ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR); + EXPECT_GE(numDevices, (uint32_t)1); + + int numMatchingDevices = 0; + for (uint32_t i = 0; i < numDevices; i++) { + ANeuralNetworksDevice* device = nullptr; + ASSERT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR); + + const char* buffer = nullptr; + ASSERT_EQ(ANeuralNetworksDevice_getName(device, &buffer), ANEURALNETWORKS_NO_ERROR); + if (deviceName == buffer) { + *outputDevice = device; + numMatchingDevices++; + } + } + + EXPECT_LE(numMatchingDevices, 1); +} + +// Test device registration with a driver parameterized with // - ErrorStatus returning from getNumberOfCacheFilesNeeded // - Number of model cache files returning from getNumberOfCacheFilesNeeded // - Number of data cache files returning from getNumberOfCacheFilesNeeded +using DeviceRegistrationTestParam = std::tuple<ErrorStatus, uint32_t, uint32_t>; + +class DeviceRegistrationTest : public ::testing::TestWithParam<DeviceRegistrationTestParam> { + protected: + static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching"; + const ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam()); + const uint32_t kNumModelCache = std::get<1>(GetParam()); + const uint32_t kNumDataCache = std::get<2>(GetParam()); + const sp<CachingDriver> kDriver = + new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache, + kNumDataCache, ErrorStatus::NONE); +}; + +TEST_P(DeviceRegistrationTest, CachingFailure) { + if (DeviceManager::get()->getUseCpuOnly()) { + return; + } + + DeviceManager::get()->forTest_registerDevice(kDeviceName.data(), kDriver); + const auto cleanup = android::base::make_scope_guard( + [] { DeviceManager::get()->forTest_reInitializeDeviceList(); }); + + // get device + const ANeuralNetworksDevice* device = nullptr; + getDeviceWithName(kDeviceName, &device); + + // check if device registeration matches expectations + const bool isDeviceRegistered = (device != nullptr); + const bool expectDeviceToBeRegistered = + canDeviceBeRegistered(kErrorStatusGetNumCacheFiles, kNumModelCache, kNumDataCache); + ASSERT_EQ(isDeviceRegistered, expectDeviceToBeRegistered); +} + +// Test model compilation with a driver parameterized with +// - Number of model cache files returning from getNumberOfCacheFilesNeeded +// - Number of data cache files returning from getNumberOfCacheFilesNeeded // - ErrorStatus returning from prepareModelFromCache -using CompilationCachingTestParam = std::tuple<ErrorStatus, uint32_t, uint32_t, ErrorStatus>; +using CompilationCachingTestParam = std::tuple<uint32_t, uint32_t, ErrorStatus>; class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachingTestParam> { protected: @@ -254,9 +332,6 @@ class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachin ASSERT_NE(cacheDir, nullptr); mCacheDir = cacheDir; CreateBroadcastAddModel(&mModel); - mToken = std::vector<uint8_t>(ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN, 0); - mIsCachingSupportedAndNoError = isCachingSupportedAndNoError(kErrorStatusGetNumCacheFiles, - kNumModelCache, kNumDataCache); } virtual void TearDown() override { @@ -266,38 +341,26 @@ class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachin } void compileModel(const sp<CachingDriver>& driver, bool withToken) { - DeviceManager::get()->forTest_registerDevice(kDeviceName, driver); - - // Make device list including only a single driver device. - uint32_t numDevices = 0; - EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR); - EXPECT_GE(numDevices, (uint32_t)1); - std::vector<ANeuralNetworksDevice*> devices; - for (uint32_t i = 0; i < numDevices; i++) { - ANeuralNetworksDevice* device = nullptr; - EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR); - const char* buffer = nullptr; - int result = ANeuralNetworksDevice_getName(device, &buffer); - if (result == ANEURALNETWORKS_NO_ERROR && strcmp(buffer, kDeviceName) == 0) { - devices.push_back(device); - break; - } - } - ASSERT_EQ(devices.size(), 1u); + DeviceManager::get()->forTest_registerDevice(kDeviceName.data(), driver); + const auto cleanup = android::base::make_scope_guard( + [] { DeviceManager::get()->forTest_reInitializeDeviceList(); }); + + // Get a handle to the single driver device matching kDeviceName. + const ANeuralNetworksDevice* device = nullptr; + getDeviceWithName(kDeviceName, &device); + ASSERT_NE(device, nullptr); // Compile the model with the device. ANeuralNetworksCompilation* compilation = nullptr; - ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), devices.data(), - devices.size(), &compilation), + ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), &device, 1, + &compilation), ANEURALNETWORKS_NO_ERROR); if (withToken) { ASSERT_EQ(ANeuralNetworksCompilation_setCaching(compilation, mCacheDir.c_str(), - mToken.data()), + kToken.data()), ANEURALNETWORKS_NO_ERROR); } ASSERT_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR); - - DeviceManager::get()->forTest_reInitializeDeviceList(); } void createCache() { @@ -306,31 +369,29 @@ class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachin compileModel(driver, /*withToken=*/true); } - static constexpr char kDeviceName[] = "deviceTestCompilationCaching"; - const ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam()); - const uint32_t kNumModelCache = std::get<1>(GetParam()); - const uint32_t kNumDataCache = std::get<2>(GetParam()); - const ErrorStatus kErrorStatusPrepareFromCache = std::get<3>(GetParam()); - bool mIsCachingSupportedAndNoError; + static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching"; + const uint32_t kNumModelCache = std::get<0>(GetParam()); + const uint32_t kNumDataCache = std::get<1>(GetParam()); + const ErrorStatus kErrorStatusPrepareFromCache = std::get<2>(GetParam()); + const bool kIsCachingSupported = isCachingSupported(kNumModelCache, kNumDataCache); test_wrapper::Model mModel; std::string mCacheDir; - std::vector<uint8_t> mToken; + const CacheToken kToken{}; }; TEST_P(CompilationCachingTest, TokenProvidedAndCacheNotExist) { if (DeviceManager::get()->getUseCpuOnly()) { return; } - sp<CachingDriver> driver = - new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache, - kNumDataCache, kErrorStatusPrepareFromCache); + sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache, + kNumDataCache, kErrorStatusPrepareFromCache); compileModel(driver, /*withToken=*/true); // When cache file does not exist, the runtime should never call prepareModelFromCache. - EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), false); + EXPECT_FALSE(driver->hasCalledPrepareModelFromCache()); // The runtime should call prepareModel_1_2. It should request caching iff caching supported. - EXPECT_EQ(driver->hasCalledPrepareModel(), mIsCachingSupportedAndNoError + EXPECT_EQ(driver->hasCalledPrepareModel(), kIsCachingSupported ? HasCalledPrepareModel::WITH_CACHING : HasCalledPrepareModel::WITHOUT_CACHING); } @@ -340,16 +401,15 @@ TEST_P(CompilationCachingTest, TokenProvidedAndCacheExist) { return; } createCache(); - sp<CachingDriver> driver = - new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache, - kNumDataCache, kErrorStatusPrepareFromCache); + sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache, + kNumDataCache, kErrorStatusPrepareFromCache); compileModel(driver, /*withToken=*/true); // When cache files exist, the runtime should call prepareModelFromCache iff caching supported. - EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), mIsCachingSupportedAndNoError); + EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), kIsCachingSupported); HasCalledPrepareModel expectHasCalledPrepareModel; - if (mIsCachingSupportedAndNoError) { + if (kIsCachingSupported) { if (kErrorStatusPrepareFromCache == ErrorStatus::NONE) { // The runtime should not call prepareModel_1_2 iff caching supported and // prepareModelFromCache succeeds. @@ -370,14 +430,13 @@ TEST_P(CompilationCachingTest, TokenNotProvided) { if (DeviceManager::get()->getUseCpuOnly()) { return; } - sp<CachingDriver> driver = - new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache, - kNumDataCache, kErrorStatusPrepareFromCache); + sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache, + kNumDataCache, kErrorStatusPrepareFromCache); compileModel(driver, /*withToken=*/false); // When no NDK token is provided by the client, the runtime should never call // prepareModelFromCache or request caching with prepareModel_1_2. - EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), false); + EXPECT_FALSE(driver->hasCalledPrepareModelFromCache()); EXPECT_EQ(driver->hasCalledPrepareModel(), HasCalledPrepareModel::WITHOUT_CACHING); } @@ -386,12 +445,18 @@ static const auto kErrorStatusGetNumCacheFilesChoices = static const auto kNumCacheChoices = testing::Values(0ul, 1ul, static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES), static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) + 1); +static const auto kNumValidCacheChoices = + testing::Values(0ul, 1ul, static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES)); static const auto kErrorStatusPrepareFromCacheChoices = testing::Values(ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE, ErrorStatus::DEVICE_UNAVAILABLE, ErrorStatus::INVALID_ARGUMENT); -INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest, +INSTANTIATE_TEST_CASE_P(TestCompilationCaching, DeviceRegistrationTest, testing::Combine(kErrorStatusGetNumCacheFilesChoices, kNumCacheChoices, - kNumCacheChoices, kErrorStatusPrepareFromCacheChoices)); + kNumCacheChoices)); + +INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest, + testing::Combine(kNumValidCacheChoices, kNumValidCacheChoices, + kErrorStatusPrepareFromCacheChoices)); } // namespace diff --git a/nn/runtime/test/TestIntrospectionControl.cpp b/nn/runtime/test/TestIntrospectionControl.cpp index 9d0cbe6c3..76a6f5ee2 100644 --- a/nn/runtime/test/TestIntrospectionControl.cpp +++ b/nn/runtime/test/TestIntrospectionControl.cpp @@ -229,9 +229,9 @@ TEST_F(IntrospectionControlTest, SimpleAddModel) { // Verify that the mCompilation is actually using the "test-all" device. CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(mCompilation); - const char* deviceNameBuffer = + const std::string& deviceNameBuffer = c->forTest_getExecutionPlan().forTest_simpleGetDevice()->getName(); - EXPECT_TRUE(driverName.compare(deviceNameBuffer) == 0); + EXPECT_EQ(driverName, deviceNameBuffer); float input1[2] = {1.0f, 2.0f}; float input2[2] = {3.0f, 4.0f}; @@ -655,7 +655,7 @@ TEST_P(TimingTest, Test) { switch (kDriverKind) { case DriverKind::CPU: { // There should be only one driver -- the CPU - const char* name = DeviceManager::get()->getDrivers()[0]->getName(); + const std::string& name = DeviceManager::get()->getDrivers()[0]->getName(); ASSERT_TRUE(selectDeviceByName(name)); break; } diff --git a/nn/runtime/test/TestPartitioning.cpp b/nn/runtime/test/TestPartitioning.cpp index b01e58015..2c774a6e8 100644 --- a/nn/runtime/test/TestPartitioning.cpp +++ b/nn/runtime/test/TestPartitioning.cpp @@ -1254,7 +1254,7 @@ TEST_F(PartitioningTest, SimpleModel) { ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(planA.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); ASSERT_NE(planA.forTest_simpleGetDevice().get(), nullptr); - ASSERT_STREQ(planA.forTest_simpleGetDevice()->getName(), "good"); + ASSERT_EQ(planA.forTest_simpleGetDevice()->getName(), "good"); // Simple partition (two devices are each capable of everything, none better than CPU). // No need to compare the original model to the model from the plan -- we @@ -1342,7 +1342,7 @@ TEST_F(PartitioningTest, SliceModel) { ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(planA.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); ASSERT_NE(planA.forTest_simpleGetDevice().get(), nullptr); - ASSERT_STREQ(planA.forTest_simpleGetDevice()->getName(), "V1_2"); + ASSERT_EQ(planA.forTest_simpleGetDevice()->getName(), "V1_2"); // Compound partition (V1_0, V1_1, V1_2 devices are available, in decreasing // order of performance; model is distributed across all three devices). @@ -1442,7 +1442,7 @@ TEST_F(PartitioningTest, SliceModelToEmpty) { ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); ASSERT_NE(plan.forTest_simpleGetDevice().get(), nullptr); - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), "V1_2"); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), "V1_2"); } TEST_F(PartitioningTest, Cpu) { @@ -1682,7 +1682,7 @@ TEST_F(PartitioningTest, OemOperations) { const auto& planBestOEM = compilationBestOEM.getExecutionPlan(); ASSERT_EQ(planBestOEM.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); ASSERT_NE(planBestOEM.forTest_simpleGetDevice().get(), nullptr); - ASSERT_STREQ(planBestOEM.forTest_simpleGetDevice()->getName(), "goodOEM"); + ASSERT_EQ(planBestOEM.forTest_simpleGetDevice()->getName(), "goodOEM"); // Verify that we get an error if no driver can run an OEM operation. const auto devicesNoOEM = makeDevices({{"noOEM", 0.5, ~0U, PartitioningDriver::OEMNo}}); @@ -1724,7 +1724,7 @@ TEST_F(PartitioningTest, RelaxedFP) { ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan), ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), expectDevice); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), expectDevice); }; ASSERT_NO_FATAL_FAILURE(TrivialTest(false, "f32")); @@ -1772,7 +1772,7 @@ TEST_F(PartitioningTest, Perf) { ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan), ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), "good"); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), "good"); } { @@ -1790,7 +1790,7 @@ TEST_F(PartitioningTest, Perf) { ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan), ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), "base"); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), "base"); } }; @@ -1853,7 +1853,7 @@ class CacheTest : public PartitioningTest { // Find the cache info for the device. const uint8_t* token = nullptr; if (plan.forTest_getKind() == ExecutionPlan::Kind::SIMPLE) { - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), deviceName); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), deviceName); token = plan.forTest_simpleGetCacheToken(); } else if (plan.forTest_getKind() == ExecutionPlan::Kind::COMPOUND) { const auto& steps = plan.forTest_compoundGetSteps(); @@ -1861,7 +1861,7 @@ class CacheTest : public PartitioningTest { for (const auto& step : steps) { // In general, two or more partitions can be on the same device. However, this will // not happen on the test models with only 2 operations. - if (strcmp(step->getDevice()->getName(), deviceName) == 0) { + if (step->getDevice()->getName() == deviceName) { ASSERT_FALSE(found); token = step->forTest_getCacheToken(); found = true; |