diff options
Diffstat (limited to 'nn/runtime/VersionedInterfaces.cpp')
-rw-r--r-- | nn/runtime/VersionedInterfaces.cpp | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index cd39a52a7..3ae950eac 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -703,8 +703,16 @@ std::optional<InitialData> initialize(const Core& core) { } std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName, - sp<V1_0::IDevice> device) { - CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object."; + const DeviceFactory& makeDevice) { + CHECK(makeDevice != nullptr) + << "VersionedIDevice::create passed invalid device factory object."; + + // get handle to IDevice object + sp<V1_0::IDevice> device = makeDevice(/*blocking=*/true); + if (device == nullptr) { + VLOG(DRIVER) << "VersionedIDevice::create got a null IDevice for " << serviceName; + return nullptr; + } auto core = Core::create(std::move(device)); if (!core.has_value()) { @@ -722,20 +730,22 @@ std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceNa std::move(*initialData); return std::make_shared<VersionedIDevice>( std::move(capabilities), std::move(supportedExtensions), type, std::move(versionString), - numberOfCacheFilesNeeded, std::move(serviceName), std::move(core.value())); + numberOfCacheFilesNeeded, std::move(serviceName), makeDevice, std::move(core.value())); } VersionedIDevice::VersionedIDevice(hal::Capabilities capabilities, std::vector<hal::Extension> supportedExtensions, int32_t type, std::string versionString, std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded, - std::string serviceName, Core core) + std::string serviceName, const DeviceFactory& makeDevice, + Core core) : kCapabilities(std::move(capabilities)), kSupportedExtensions(std::move(supportedExtensions)), kType(type), kVersionString(std::move(versionString)), kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded), kServiceName(std::move(serviceName)), + kMakeDevice(makeDevice), mCore(std::move(core)) {} std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) { @@ -874,7 +884,7 @@ Return<T_Return> VersionedIDevice::recoverable( if (pingReturn.isDeadObject()) { VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering " << kServiceName; - sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(kServiceName); + sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/false); if (recoveredDevice == nullptr) { VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for " << kServiceName; @@ -911,7 +921,7 @@ int VersionedIDevice::wait() const { auto pingReturn = mCore.getDevice<V1_0::IDevice>()->ping(); if (pingReturn.isDeadObject()) { VLOG(DRIVER) << "VersionedIDevice::wait -- Recovering " << kServiceName; - sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::getService(kServiceName); + sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/true); if (recoveredDevice == nullptr) { LOG(ERROR) << "VersionedIDevice::wait got a null IDevice for " << kServiceName; return ANEURALNETWORKS_OP_FAILED; |