diff options
Diffstat (limited to 'nn/runtime/Manager.cpp')
-rw-r--r-- | nn/runtime/Manager.cpp | 37 |
1 files changed, 19 insertions, 18 deletions
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp index 7be8419ad..310710e3c 100644 --- a/nn/runtime/Manager.cpp +++ b/nn/runtime/Manager.cpp @@ -55,9 +55,10 @@ const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX // A Device with actual underlying driver class DriverDevice : public Device { public: - // Create a DriverDevice from a name and an IDevice. + // Create a DriverDevice from a name and a DeviceFactory function. // Returns nullptr on failure. - static std::shared_ptr<DriverDevice> create(std::string name, sp<V1_0::IDevice> device); + static std::shared_ptr<DriverDevice> create(const std::string& name, + const DeviceFactory& makeDevice); // Prefer using DriverDevice::create DriverDevice(std::shared_ptr<VersionedIDevice> device); @@ -159,6 +160,7 @@ class DriverPreparedModel : public PreparedModel { DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device) : kInterface(std::move(device)) { + CHECK(kInterface != nullptr); #ifdef NN_DEBUGGABLE static const char samplePrefix[] = "sample"; if (getName().substr(0, sizeof(samplePrefix) - 1) == samplePrefix) { @@ -167,17 +169,17 @@ DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device) #endif // NN_DEBUGGABLE } -std::shared_ptr<DriverDevice> DriverDevice::create(std::string name, sp<V1_0::IDevice> device) { - CHECK(device != nullptr); - std::shared_ptr<VersionedIDevice> versionedDevice = - VersionedIDevice::create(name, std::move(device)); - if (versionedDevice == nullptr) { +std::shared_ptr<DriverDevice> DriverDevice::create(const std::string& name, + const DeviceFactory& makeDevice) { + CHECK(makeDevice != nullptr); + std::shared_ptr<VersionedIDevice> device = VersionedIDevice::create(name, makeDevice); + if (device == nullptr) { LOG(ERROR) << "DriverDevice::create failed to create VersionedIDevice object for service " << name; return nullptr; } - return std::make_shared<DriverDevice>(std::move(versionedDevice)); + return std::make_shared<DriverDevice>(std::move(device)); } std::vector<bool> DriverDevice::getSupportedOperations(const MetaModel& metaModel) const { @@ -817,7 +819,8 @@ std::shared_ptr<Device> DeviceManager::getCpuDevice() { std::shared_ptr<Device> DeviceManager::forTest_makeDriverDevice(const std::string& name, const sp<V1_0::IDevice>& device) { - const auto driverDevice = DriverDevice::create(name, device); + const DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; }; + const auto driverDevice = DriverDevice::create(name, makeDevice); CHECK(driverDevice != nullptr); return driverDevice; } @@ -829,12 +832,10 @@ void DeviceManager::findAvailableDevices() { const auto names = hardware::getAllHalInstanceNames(V1_0::IDevice::descriptor); for (const auto& name : names) { VLOG(MANAGER) << "Found interface " << name; - sp<V1_0::IDevice> device = V1_0::IDevice::getService(name); - if (device == nullptr) { - LOG(ERROR) << "Got a null IDEVICE for " << name; - continue; - } - registerDevice(name, device); + const DeviceFactory makeDevice = [name](bool blocking) { + return blocking ? V1_0::IDevice::getService(name) : V1_0::IDevice::tryGetService(name); + }; + registerDevice(name, makeDevice); } // register CPU fallback device @@ -842,9 +843,9 @@ void DeviceManager::findAvailableDevices() { mDevicesCpuOnly.push_back(CpuDevice::get()); } -void DeviceManager::registerDevice(const std::string& name, const sp<V1_0::IDevice>& device) { - if (const auto d = DriverDevice::create(name, device)) { - mDevices.push_back(d); +void DeviceManager::registerDevice(const std::string& name, const DeviceFactory& makeDevice) { + if (auto device = DriverDevice::create(name, makeDevice)) { + mDevices.push_back(std::move(device)); } } |