summaryrefslogtreecommitdiff
path: root/nn/runtime/VersionedInterfaces.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/runtime/VersionedInterfaces.cpp')
-rw-r--r--nn/runtime/VersionedInterfaces.cpp121
1 files changed, 110 insertions, 11 deletions
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index ba6e2af7c..8da4e7c1f 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -306,21 +306,88 @@ std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExec
std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
sp<V1_0::IDevice> device) {
+ CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object.";
+
auto core = Core::create(std::move(device));
if (!core.has_value()) {
LOG(ERROR) << "VersionedIDevice::create failed to create Core.";
return nullptr;
}
- // return a valid VersionedIDevice object
- return std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value()));
+ // create and initialize a VersionedIDevice object
+ const auto versionedIDevice =
+ std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value()));
+ if (!versionedIDevice->initializeInternal()) {
+ LOG(ERROR) << "VersionedIDevice failed to initialize";
+ return nullptr;
+ }
+
+ // return a valid, initialized VersionedIDevice object
+ return versionedIDevice;
}
VersionedIDevice::VersionedIDevice(std::string serviceName, Core core)
: mServiceName(std::move(serviceName)), mCore(std::move(core)) {}
+bool VersionedIDevice::initializeInternal() {
+ auto [capabilitiesStatus, capabilities] = getCapabilitiesInternal();
+ if (capabilitiesStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "IDevice::getCapabilities* returned the error "
+ << toString(capabilitiesStatus);
+ return false;
+ }
+ VLOG(MANAGER) << "Capab " << toString(capabilities);
+
+ const auto [versionStatus, versionString] = getVersionStringInternal();
+ // TODO(miaowang): add a validation test case for in case of error.
+ if (versionStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(versionStatus);
+ return false;
+ }
+
+ const int32_t type = getTypeInternal();
+ if (type == -1) {
+ LOG(ERROR) << "IDevice::getType returned an error";
+ return false;
+ }
+
+ const auto [extensionsStatus, supportedExtensions] = getSupportedExtensionsInternal();
+ if (extensionsStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "IDevice::getSupportedExtensions returned the error "
+ << toString(extensionsStatus);
+ return false;
+ }
+
+ const auto [cacheFilesStatus, numModelCacheFiles, numDataCacheFiles] =
+ getNumberOfCacheFilesNeededInternal();
+ if (cacheFilesStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "IDevice::getNumberOfCacheFilesNeeded returned the error "
+ << toString(cacheFilesStatus);
+ return false;
+ }
+
+ // The following limit is enforced by VTS
+ constexpr uint32_t maxNumCacheFiles =
+ static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES);
+ if (numModelCacheFiles > maxNumCacheFiles || numDataCacheFiles > maxNumCacheFiles) {
+ LOG(ERROR)
+ << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files: "
+ "numModelCacheFiles = "
+ << numModelCacheFiles << ", numDataCacheFiles = " << numDataCacheFiles
+ << ", maxNumCacheFiles = " << maxNumCacheFiles;
+ return false;
+ }
+
+ // set internal members
+ mCapabilities = std::move(capabilities);
+ mSupportedExtensions = supportedExtensions;
+ mType = type;
+ mVersionString = versionString;
+ mNumberOfCacheFilesNeeded = {numModelCacheFiles, numDataCacheFiles};
+ return true;
+}
+
std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
- // verify input
CHECK(device != nullptr) << "VersionedIDevice::Core::create passed invalid device object.";
// create death handler object
@@ -480,7 +547,7 @@ Return<T_Return> VersionedIDevice::recoverable(
return ret;
}
-std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() const {
+std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilitiesInternal() const {
const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
std::pair<ErrorStatus, Capabilities> result;
@@ -559,7 +626,12 @@ std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() const {
return {ErrorStatus::DEVICE_UNAVAILABLE, {}};
}
-std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() const {
+const Capabilities& VersionedIDevice::getCapabilities() const {
+ return mCapabilities;
+}
+
+std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensionsInternal()
+ const {
const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
// version 1.2+ HAL
@@ -590,6 +662,10 @@ std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtens
return {ErrorStatus::DEVICE_UNAVAILABLE, {}};
}
+const std::vector<Extension>& VersionedIDevice::getSupportedExtensions() const {
+ return mSupportedExtensions;
+}
+
std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
const MetaModel& metaModel) const {
const std::pair<ErrorStatus, hidl_vec<bool>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
@@ -967,7 +1043,7 @@ int64_t VersionedIDevice::getFeatureLevel() const {
}
}
-int32_t VersionedIDevice::getType() const {
+int32_t VersionedIDevice::getTypeInternal() const {
constexpr int32_t kFailure = -1;
// version 1.2+ HAL
@@ -988,12 +1064,22 @@ int32_t VersionedIDevice::getType() const {
return result;
}
- // version too low or no device available
- LOG(INFO) << "Unknown NNAPI device type.";
- return ANEURALNETWORKS_DEVICE_UNKNOWN;
+ // version too low
+ if (getDevice<V1_0::IDevice>() != nullptr) {
+ LOG(INFO) << "Unknown NNAPI device type.";
+ return ANEURALNETWORKS_DEVICE_UNKNOWN;
+ }
+
+ // No device available
+ LOG(ERROR) << "Could not handle getType";
+ return kFailure;
+}
+
+int32_t VersionedIDevice::getType() const {
+ return mType;
}
-std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() const {
+std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionStringInternal() const {
const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""};
// version 1.2+ HAL
@@ -1023,7 +1109,12 @@ std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() const {
return kFailure;
}
-std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const {
+const std::string& VersionedIDevice::getVersionString() const {
+ return mVersionString;
+}
+
+std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeededInternal()
+ const {
constexpr std::tuple<ErrorStatus, uint32_t, uint32_t> kFailure = {ErrorStatus::GENERAL_FAILURE,
0, 0};
@@ -1055,5 +1146,13 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi
return kFailure;
}
+std::pair<uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const {
+ return mNumberOfCacheFilesNeeded;
+}
+
+const std::string& VersionedIDevice::getName() const {
+ return mServiceName;
+}
+
} // namespace nn
} // namespace android