diff options
Diffstat (limited to 'nn/runtime/VersionedInterfaces.cpp')
-rw-r--r-- | nn/runtime/VersionedInterfaces.cpp | 121 |
1 files changed, 110 insertions, 11 deletions
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index ba6e2af7c..8da4e7c1f 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -306,21 +306,88 @@ std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExec std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName, sp<V1_0::IDevice> device) { + CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object."; + auto core = Core::create(std::move(device)); if (!core.has_value()) { LOG(ERROR) << "VersionedIDevice::create failed to create Core."; return nullptr; } - // return a valid VersionedIDevice object - return std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value())); + // create and initialize a VersionedIDevice object + const auto versionedIDevice = + std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value())); + if (!versionedIDevice->initializeInternal()) { + LOG(ERROR) << "VersionedIDevice failed to initialize"; + return nullptr; + } + + // return a valid, initialized VersionedIDevice object + return versionedIDevice; } VersionedIDevice::VersionedIDevice(std::string serviceName, Core core) : mServiceName(std::move(serviceName)), mCore(std::move(core)) {} +bool VersionedIDevice::initializeInternal() { + auto [capabilitiesStatus, capabilities] = getCapabilitiesInternal(); + if (capabilitiesStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getCapabilities* returned the error " + << toString(capabilitiesStatus); + return false; + } + VLOG(MANAGER) << "Capab " << toString(capabilities); + + const auto [versionStatus, versionString] = getVersionStringInternal(); + // TODO(miaowang): add a validation test case for in case of error. + if (versionStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(versionStatus); + return false; + } + + const int32_t type = getTypeInternal(); + if (type == -1) { + LOG(ERROR) << "IDevice::getType returned an error"; + return false; + } + + const auto [extensionsStatus, supportedExtensions] = getSupportedExtensionsInternal(); + if (extensionsStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getSupportedExtensions returned the error " + << toString(extensionsStatus); + return false; + } + + const auto [cacheFilesStatus, numModelCacheFiles, numDataCacheFiles] = + getNumberOfCacheFilesNeededInternal(); + if (cacheFilesStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getNumberOfCacheFilesNeeded returned the error " + << toString(cacheFilesStatus); + return false; + } + + // The following limit is enforced by VTS + constexpr uint32_t maxNumCacheFiles = + static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES); + if (numModelCacheFiles > maxNumCacheFiles || numDataCacheFiles > maxNumCacheFiles) { + LOG(ERROR) + << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files: " + "numModelCacheFiles = " + << numModelCacheFiles << ", numDataCacheFiles = " << numDataCacheFiles + << ", maxNumCacheFiles = " << maxNumCacheFiles; + return false; + } + + // set internal members + mCapabilities = std::move(capabilities); + mSupportedExtensions = supportedExtensions; + mType = type; + mVersionString = versionString; + mNumberOfCacheFilesNeeded = {numModelCacheFiles, numDataCacheFiles}; + return true; +} + std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) { - // verify input CHECK(device != nullptr) << "VersionedIDevice::Core::create passed invalid device object."; // create death handler object @@ -480,7 +547,7 @@ Return<T_Return> VersionedIDevice::recoverable( return ret; } -std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() const { +std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilitiesInternal() const { const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}}; std::pair<ErrorStatus, Capabilities> result; @@ -559,7 +626,12 @@ std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() const { return {ErrorStatus::DEVICE_UNAVAILABLE, {}}; } -std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() const { +const Capabilities& VersionedIDevice::getCapabilities() const { + return mCapabilities; +} + +std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensionsInternal() + const { const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}}; // version 1.2+ HAL @@ -590,6 +662,10 @@ std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtens return {ErrorStatus::DEVICE_UNAVAILABLE, {}}; } +const std::vector<Extension>& VersionedIDevice::getSupportedExtensions() const { + return mSupportedExtensions; +} + std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations( const MetaModel& metaModel) const { const std::pair<ErrorStatus, hidl_vec<bool>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}}; @@ -967,7 +1043,7 @@ int64_t VersionedIDevice::getFeatureLevel() const { } } -int32_t VersionedIDevice::getType() const { +int32_t VersionedIDevice::getTypeInternal() const { constexpr int32_t kFailure = -1; // version 1.2+ HAL @@ -988,12 +1064,22 @@ int32_t VersionedIDevice::getType() const { return result; } - // version too low or no device available - LOG(INFO) << "Unknown NNAPI device type."; - return ANEURALNETWORKS_DEVICE_UNKNOWN; + // version too low + if (getDevice<V1_0::IDevice>() != nullptr) { + LOG(INFO) << "Unknown NNAPI device type."; + return ANEURALNETWORKS_DEVICE_UNKNOWN; + } + + // No device available + LOG(ERROR) << "Could not handle getType"; + return kFailure; +} + +int32_t VersionedIDevice::getType() const { + return mType; } -std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() const { +std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionStringInternal() const { const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""}; // version 1.2+ HAL @@ -1023,7 +1109,12 @@ std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() const { return kFailure; } -std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const { +const std::string& VersionedIDevice::getVersionString() const { + return mVersionString; +} + +std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeededInternal() + const { constexpr std::tuple<ErrorStatus, uint32_t, uint32_t> kFailure = {ErrorStatus::GENERAL_FAILURE, 0, 0}; @@ -1055,5 +1146,13 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi return kFailure; } +std::pair<uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const { + return mNumberOfCacheFilesNeeded; +} + +const std::string& VersionedIDevice::getName() const { + return mServiceName; +} + } // namespace nn } // namespace android |