diff options
-rw-r--r-- | nn/common/include/HalInterfaces.h | 1 | ||||
-rw-r--r-- | nn/runtime/Manager.cpp | 37 | ||||
-rw-r--r-- | nn/runtime/Manager.h | 5 | ||||
-rw-r--r-- | nn/runtime/Memory.cpp | 2 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.cpp | 22 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.h | 13 |
6 files changed, 49 insertions, 31 deletions
diff --git a/nn/common/include/HalInterfaces.h b/nn/common/include/HalInterfaces.h index fe1ff563e..4e3a3800b 100644 --- a/nn/common/include/HalInterfaces.h +++ b/nn/common/include/HalInterfaces.h @@ -103,6 +103,7 @@ using OperandExtraParams = V1_2::Operand::ExtraParams; using CacheToken = hardware::hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>; +using DeviceFactory = std::function<sp<V1_0::IDevice>(bool blocking)>; using ModelFactory = std::function<Model()>; inline constexpr Priority kDefaultPriority = Priority::MEDIUM; diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp index 7be8419ad..310710e3c 100644 --- a/nn/runtime/Manager.cpp +++ b/nn/runtime/Manager.cpp @@ -55,9 +55,10 @@ const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX // A Device with actual underlying driver class DriverDevice : public Device { public: - // Create a DriverDevice from a name and an IDevice. + // Create a DriverDevice from a name and a DeviceFactory function. // Returns nullptr on failure. - static std::shared_ptr<DriverDevice> create(std::string name, sp<V1_0::IDevice> device); + static std::shared_ptr<DriverDevice> create(const std::string& name, + const DeviceFactory& makeDevice); // Prefer using DriverDevice::create DriverDevice(std::shared_ptr<VersionedIDevice> device); @@ -159,6 +160,7 @@ class DriverPreparedModel : public PreparedModel { DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device) : kInterface(std::move(device)) { + CHECK(kInterface != nullptr); #ifdef NN_DEBUGGABLE static const char samplePrefix[] = "sample"; if (getName().substr(0, sizeof(samplePrefix) - 1) == samplePrefix) { @@ -167,17 +169,17 @@ DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device) #endif // NN_DEBUGGABLE } -std::shared_ptr<DriverDevice> DriverDevice::create(std::string name, sp<V1_0::IDevice> device) { - CHECK(device != nullptr); - std::shared_ptr<VersionedIDevice> versionedDevice = - VersionedIDevice::create(name, std::move(device)); - if (versionedDevice == nullptr) { +std::shared_ptr<DriverDevice> DriverDevice::create(const std::string& name, + const DeviceFactory& makeDevice) { + CHECK(makeDevice != nullptr); + std::shared_ptr<VersionedIDevice> device = VersionedIDevice::create(name, makeDevice); + if (device == nullptr) { LOG(ERROR) << "DriverDevice::create failed to create VersionedIDevice object for service " << name; return nullptr; } - return std::make_shared<DriverDevice>(std::move(versionedDevice)); + return std::make_shared<DriverDevice>(std::move(device)); } std::vector<bool> DriverDevice::getSupportedOperations(const MetaModel& metaModel) const { @@ -817,7 +819,8 @@ std::shared_ptr<Device> DeviceManager::getCpuDevice() { std::shared_ptr<Device> DeviceManager::forTest_makeDriverDevice(const std::string& name, const sp<V1_0::IDevice>& device) { - const auto driverDevice = DriverDevice::create(name, device); + const DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; }; + const auto driverDevice = DriverDevice::create(name, makeDevice); CHECK(driverDevice != nullptr); return driverDevice; } @@ -829,12 +832,10 @@ void DeviceManager::findAvailableDevices() { const auto names = hardware::getAllHalInstanceNames(V1_0::IDevice::descriptor); for (const auto& name : names) { VLOG(MANAGER) << "Found interface " << name; - sp<V1_0::IDevice> device = V1_0::IDevice::getService(name); - if (device == nullptr) { - LOG(ERROR) << "Got a null IDEVICE for " << name; - continue; - } - registerDevice(name, device); + const DeviceFactory makeDevice = [name](bool blocking) { + return blocking ? V1_0::IDevice::getService(name) : V1_0::IDevice::tryGetService(name); + }; + registerDevice(name, makeDevice); } // register CPU fallback device @@ -842,9 +843,9 @@ void DeviceManager::findAvailableDevices() { mDevicesCpuOnly.push_back(CpuDevice::get()); } -void DeviceManager::registerDevice(const std::string& name, const sp<V1_0::IDevice>& device) { - if (const auto d = DriverDevice::create(name, device)) { - mDevices.push_back(d); +void DeviceManager::registerDevice(const std::string& name, const DeviceFactory& makeDevice) { + if (auto device = DriverDevice::create(name, makeDevice)) { + mDevices.push_back(std::move(device)); } } diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h index c28ee49a6..d6d483576 100644 --- a/nn/runtime/Manager.h +++ b/nn/runtime/Manager.h @@ -169,7 +169,8 @@ class DeviceManager { // Register a test device. void forTest_registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device) { - registerDevice(name, device); + const hal::DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; }; + registerDevice(name, makeDevice); } // Re-initialize the list of available devices. @@ -192,7 +193,7 @@ class DeviceManager { DeviceManager(); // Adds a device for the manager to use. - void registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device); + void registerDevice(const std::string& name, const hal::DeviceFactory& makeDevice); void findAvailableDevices(); diff --git a/nn/runtime/Memory.cpp b/nn/runtime/Memory.cpp index 7bfaf5562..e0bd6b953 100644 --- a/nn/runtime/Memory.cpp +++ b/nn/runtime/Memory.cpp @@ -194,7 +194,7 @@ Memory::Memory(sp<hal::IBuffer> buffer, uint32_t token) : kBuffer(std::move(buffer)), kToken(token) {} Memory::~Memory() { - for (const auto [ptr, weakBurst] : mUsedBy) { + for (const auto& [ptr, weakBurst] : mUsedBy) { if (const std::shared_ptr<ExecutionBurstController> burst = weakBurst.lock()) { burst->freeMemory(getKey()); } diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index cd39a52a7..3ae950eac 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -703,8 +703,16 @@ std::optional<InitialData> initialize(const Core& core) { } std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName, - sp<V1_0::IDevice> device) { - CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object."; + const DeviceFactory& makeDevice) { + CHECK(makeDevice != nullptr) + << "VersionedIDevice::create passed invalid device factory object."; + + // get handle to IDevice object + sp<V1_0::IDevice> device = makeDevice(/*blocking=*/true); + if (device == nullptr) { + VLOG(DRIVER) << "VersionedIDevice::create got a null IDevice for " << serviceName; + return nullptr; + } auto core = Core::create(std::move(device)); if (!core.has_value()) { @@ -722,20 +730,22 @@ std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceNa std::move(*initialData); return std::make_shared<VersionedIDevice>( std::move(capabilities), std::move(supportedExtensions), type, std::move(versionString), - numberOfCacheFilesNeeded, std::move(serviceName), std::move(core.value())); + numberOfCacheFilesNeeded, std::move(serviceName), makeDevice, std::move(core.value())); } VersionedIDevice::VersionedIDevice(hal::Capabilities capabilities, std::vector<hal::Extension> supportedExtensions, int32_t type, std::string versionString, std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded, - std::string serviceName, Core core) + std::string serviceName, const DeviceFactory& makeDevice, + Core core) : kCapabilities(std::move(capabilities)), kSupportedExtensions(std::move(supportedExtensions)), kType(type), kVersionString(std::move(versionString)), kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded), kServiceName(std::move(serviceName)), + kMakeDevice(makeDevice), mCore(std::move(core)) {} std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) { @@ -874,7 +884,7 @@ Return<T_Return> VersionedIDevice::recoverable( if (pingReturn.isDeadObject()) { VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering " << kServiceName; - sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(kServiceName); + sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/false); if (recoveredDevice == nullptr) { VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for " << kServiceName; @@ -911,7 +921,7 @@ int VersionedIDevice::wait() const { auto pingReturn = mCore.getDevice<V1_0::IDevice>()->ping(); if (pingReturn.isDeadObject()) { VLOG(DRIVER) << "VersionedIDevice::wait -- Recovering " << kServiceName; - sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::getService(kServiceName); + sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/true); if (recoveredDevice == nullptr) { LOG(ERROR) << "VersionedIDevice::wait got a null IDevice for " << kServiceName; return ANEURALNETWORKS_OP_FAILED; diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h index 94ca3fe49..efde0bdf5 100644 --- a/nn/runtime/VersionedInterfaces.h +++ b/nn/runtime/VersionedInterfaces.h @@ -72,12 +72,12 @@ class VersionedIDevice { * protections. * * @param serviceName The name of the service that provides "device". - * @param device A device object that is at least version 1.0 of the IDevice - * interface. + * @param makeDevice A device factory function that returns a device object + * that is at least version 1.0 of the IDevice interface. * @return A valid VersionedIDevice object, otherwise nullptr. */ static std::shared_ptr<VersionedIDevice> create(std::string serviceName, - sp<hal::V1_0::IDevice> device); + const hal::DeviceFactory& makeDevice); /** * Constructor for the VersionedIDevice object. @@ -92,6 +92,8 @@ class VersionedIDevice { * @param numberOfCacheFilesNeeded Number of model cache and data cache * files needed by the driver. * @param serviceName The name of the service that provides core.getDevice<V1_0::IDevice>(). + * @param makeDevice A device factory function that returns a device object + * that is at least version 1.0 of the IDevice interface. * @param core An object that encapsulates a V1_0::IDevice, any appropriate downcasts to * newer interfaces, and a hidl_death_recipient that will proactively handle * the case when the service containing the IDevice object crashes. @@ -100,7 +102,7 @@ class VersionedIDevice { std::vector<hal::Extension> supportedExtensions, int32_t type, std::string versionString, std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded, - std::string serviceName, Core core); + std::string serviceName, const hal::DeviceFactory& makeDevice, Core core); /** * Gets the capabilities of a driver. @@ -554,6 +556,9 @@ class VersionedIDevice { // The name of the service that implements the driver. const std::string kServiceName; + // Factory function object to generate an IDevice object. + const hal::DeviceFactory kMakeDevice; + // Guards access to mCore. mutable std::shared_mutex mMutex; |