summaryrefslogtreecommitdiff
path: root/nn/runtime/Manager.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/runtime/Manager.cpp')
-rw-r--r--nn/runtime/Manager.cpp37
1 files changed, 19 insertions, 18 deletions
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index 7be8419ad..310710e3c 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -55,9 +55,10 @@ const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX
// A Device with actual underlying driver
class DriverDevice : public Device {
public:
- // Create a DriverDevice from a name and an IDevice.
+ // Create a DriverDevice from a name and a DeviceFactory function.
// Returns nullptr on failure.
- static std::shared_ptr<DriverDevice> create(std::string name, sp<V1_0::IDevice> device);
+ static std::shared_ptr<DriverDevice> create(const std::string& name,
+ const DeviceFactory& makeDevice);
// Prefer using DriverDevice::create
DriverDevice(std::shared_ptr<VersionedIDevice> device);
@@ -159,6 +160,7 @@ class DriverPreparedModel : public PreparedModel {
DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device)
: kInterface(std::move(device)) {
+ CHECK(kInterface != nullptr);
#ifdef NN_DEBUGGABLE
static const char samplePrefix[] = "sample";
if (getName().substr(0, sizeof(samplePrefix) - 1) == samplePrefix) {
@@ -167,17 +169,17 @@ DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device)
#endif // NN_DEBUGGABLE
}
-std::shared_ptr<DriverDevice> DriverDevice::create(std::string name, sp<V1_0::IDevice> device) {
- CHECK(device != nullptr);
- std::shared_ptr<VersionedIDevice> versionedDevice =
- VersionedIDevice::create(name, std::move(device));
- if (versionedDevice == nullptr) {
+std::shared_ptr<DriverDevice> DriverDevice::create(const std::string& name,
+ const DeviceFactory& makeDevice) {
+ CHECK(makeDevice != nullptr);
+ std::shared_ptr<VersionedIDevice> device = VersionedIDevice::create(name, makeDevice);
+ if (device == nullptr) {
LOG(ERROR) << "DriverDevice::create failed to create VersionedIDevice object for service "
<< name;
return nullptr;
}
- return std::make_shared<DriverDevice>(std::move(versionedDevice));
+ return std::make_shared<DriverDevice>(std::move(device));
}
std::vector<bool> DriverDevice::getSupportedOperations(const MetaModel& metaModel) const {
@@ -817,7 +819,8 @@ std::shared_ptr<Device> DeviceManager::getCpuDevice() {
std::shared_ptr<Device> DeviceManager::forTest_makeDriverDevice(const std::string& name,
const sp<V1_0::IDevice>& device) {
- const auto driverDevice = DriverDevice::create(name, device);
+ const DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; };
+ const auto driverDevice = DriverDevice::create(name, makeDevice);
CHECK(driverDevice != nullptr);
return driverDevice;
}
@@ -829,12 +832,10 @@ void DeviceManager::findAvailableDevices() {
const auto names = hardware::getAllHalInstanceNames(V1_0::IDevice::descriptor);
for (const auto& name : names) {
VLOG(MANAGER) << "Found interface " << name;
- sp<V1_0::IDevice> device = V1_0::IDevice::getService(name);
- if (device == nullptr) {
- LOG(ERROR) << "Got a null IDEVICE for " << name;
- continue;
- }
- registerDevice(name, device);
+ const DeviceFactory makeDevice = [name](bool blocking) {
+ return blocking ? V1_0::IDevice::getService(name) : V1_0::IDevice::tryGetService(name);
+ };
+ registerDevice(name, makeDevice);
}
// register CPU fallback device
@@ -842,9 +843,9 @@ void DeviceManager::findAvailableDevices() {
mDevicesCpuOnly.push_back(CpuDevice::get());
}
-void DeviceManager::registerDevice(const std::string& name, const sp<V1_0::IDevice>& device) {
- if (const auto d = DriverDevice::create(name, device)) {
- mDevices.push_back(d);
+void DeviceManager::registerDevice(const std::string& name, const DeviceFactory& makeDevice) {
+ if (auto device = DriverDevice::create(name, makeDevice)) {
+ mDevices.push_back(std::move(device));
}
}