summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nn/common/include/HalInterfaces.h1
-rw-r--r--nn/runtime/Manager.cpp37
-rw-r--r--nn/runtime/Manager.h5
-rw-r--r--nn/runtime/Memory.cpp2
-rw-r--r--nn/runtime/VersionedInterfaces.cpp22
-rw-r--r--nn/runtime/VersionedInterfaces.h13
6 files changed, 49 insertions, 31 deletions
diff --git a/nn/common/include/HalInterfaces.h b/nn/common/include/HalInterfaces.h
index fe1ff563e..4e3a3800b 100644
--- a/nn/common/include/HalInterfaces.h
+++ b/nn/common/include/HalInterfaces.h
@@ -103,6 +103,7 @@ using OperandExtraParams = V1_2::Operand::ExtraParams;
using CacheToken =
hardware::hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
+using DeviceFactory = std::function<sp<V1_0::IDevice>(bool blocking)>;
using ModelFactory = std::function<Model()>;
inline constexpr Priority kDefaultPriority = Priority::MEDIUM;
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index 7be8419ad..310710e3c 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -55,9 +55,10 @@ const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX
// A Device with actual underlying driver
class DriverDevice : public Device {
public:
- // Create a DriverDevice from a name and an IDevice.
+ // Create a DriverDevice from a name and a DeviceFactory function.
// Returns nullptr on failure.
- static std::shared_ptr<DriverDevice> create(std::string name, sp<V1_0::IDevice> device);
+ static std::shared_ptr<DriverDevice> create(const std::string& name,
+ const DeviceFactory& makeDevice);
// Prefer using DriverDevice::create
DriverDevice(std::shared_ptr<VersionedIDevice> device);
@@ -159,6 +160,7 @@ class DriverPreparedModel : public PreparedModel {
DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device)
: kInterface(std::move(device)) {
+ CHECK(kInterface != nullptr);
#ifdef NN_DEBUGGABLE
static const char samplePrefix[] = "sample";
if (getName().substr(0, sizeof(samplePrefix) - 1) == samplePrefix) {
@@ -167,17 +169,17 @@ DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device)
#endif // NN_DEBUGGABLE
}
-std::shared_ptr<DriverDevice> DriverDevice::create(std::string name, sp<V1_0::IDevice> device) {
- CHECK(device != nullptr);
- std::shared_ptr<VersionedIDevice> versionedDevice =
- VersionedIDevice::create(name, std::move(device));
- if (versionedDevice == nullptr) {
+std::shared_ptr<DriverDevice> DriverDevice::create(const std::string& name,
+ const DeviceFactory& makeDevice) {
+ CHECK(makeDevice != nullptr);
+ std::shared_ptr<VersionedIDevice> device = VersionedIDevice::create(name, makeDevice);
+ if (device == nullptr) {
LOG(ERROR) << "DriverDevice::create failed to create VersionedIDevice object for service "
<< name;
return nullptr;
}
- return std::make_shared<DriverDevice>(std::move(versionedDevice));
+ return std::make_shared<DriverDevice>(std::move(device));
}
std::vector<bool> DriverDevice::getSupportedOperations(const MetaModel& metaModel) const {
@@ -817,7 +819,8 @@ std::shared_ptr<Device> DeviceManager::getCpuDevice() {
std::shared_ptr<Device> DeviceManager::forTest_makeDriverDevice(const std::string& name,
const sp<V1_0::IDevice>& device) {
- const auto driverDevice = DriverDevice::create(name, device);
+ const DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; };
+ const auto driverDevice = DriverDevice::create(name, makeDevice);
CHECK(driverDevice != nullptr);
return driverDevice;
}
@@ -829,12 +832,10 @@ void DeviceManager::findAvailableDevices() {
const auto names = hardware::getAllHalInstanceNames(V1_0::IDevice::descriptor);
for (const auto& name : names) {
VLOG(MANAGER) << "Found interface " << name;
- sp<V1_0::IDevice> device = V1_0::IDevice::getService(name);
- if (device == nullptr) {
- LOG(ERROR) << "Got a null IDEVICE for " << name;
- continue;
- }
- registerDevice(name, device);
+ const DeviceFactory makeDevice = [name](bool blocking) {
+ return blocking ? V1_0::IDevice::getService(name) : V1_0::IDevice::tryGetService(name);
+ };
+ registerDevice(name, makeDevice);
}
// register CPU fallback device
@@ -842,9 +843,9 @@ void DeviceManager::findAvailableDevices() {
mDevicesCpuOnly.push_back(CpuDevice::get());
}
-void DeviceManager::registerDevice(const std::string& name, const sp<V1_0::IDevice>& device) {
- if (const auto d = DriverDevice::create(name, device)) {
- mDevices.push_back(d);
+void DeviceManager::registerDevice(const std::string& name, const DeviceFactory& makeDevice) {
+ if (auto device = DriverDevice::create(name, makeDevice)) {
+ mDevices.push_back(std::move(device));
}
}
diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h
index c28ee49a6..d6d483576 100644
--- a/nn/runtime/Manager.h
+++ b/nn/runtime/Manager.h
@@ -169,7 +169,8 @@ class DeviceManager {
// Register a test device.
void forTest_registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device) {
- registerDevice(name, device);
+ const hal::DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; };
+ registerDevice(name, makeDevice);
}
// Re-initialize the list of available devices.
@@ -192,7 +193,7 @@ class DeviceManager {
DeviceManager();
// Adds a device for the manager to use.
- void registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device);
+ void registerDevice(const std::string& name, const hal::DeviceFactory& makeDevice);
void findAvailableDevices();
diff --git a/nn/runtime/Memory.cpp b/nn/runtime/Memory.cpp
index 7bfaf5562..e0bd6b953 100644
--- a/nn/runtime/Memory.cpp
+++ b/nn/runtime/Memory.cpp
@@ -194,7 +194,7 @@ Memory::Memory(sp<hal::IBuffer> buffer, uint32_t token)
: kBuffer(std::move(buffer)), kToken(token) {}
Memory::~Memory() {
- for (const auto [ptr, weakBurst] : mUsedBy) {
+ for (const auto& [ptr, weakBurst] : mUsedBy) {
if (const std::shared_ptr<ExecutionBurstController> burst = weakBurst.lock()) {
burst->freeMemory(getKey());
}
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index cd39a52a7..3ae950eac 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -703,8 +703,16 @@ std::optional<InitialData> initialize(const Core& core) {
}
std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
- sp<V1_0::IDevice> device) {
- CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object.";
+ const DeviceFactory& makeDevice) {
+ CHECK(makeDevice != nullptr)
+ << "VersionedIDevice::create passed invalid device factory object.";
+
+ // get handle to IDevice object
+ sp<V1_0::IDevice> device = makeDevice(/*blocking=*/true);
+ if (device == nullptr) {
+ VLOG(DRIVER) << "VersionedIDevice::create got a null IDevice for " << serviceName;
+ return nullptr;
+ }
auto core = Core::create(std::move(device));
if (!core.has_value()) {
@@ -722,20 +730,22 @@ std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceNa
std::move(*initialData);
return std::make_shared<VersionedIDevice>(
std::move(capabilities), std::move(supportedExtensions), type, std::move(versionString),
- numberOfCacheFilesNeeded, std::move(serviceName), std::move(core.value()));
+ numberOfCacheFilesNeeded, std::move(serviceName), makeDevice, std::move(core.value()));
}
VersionedIDevice::VersionedIDevice(hal::Capabilities capabilities,
std::vector<hal::Extension> supportedExtensions, int32_t type,
std::string versionString,
std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
- std::string serviceName, Core core)
+ std::string serviceName, const DeviceFactory& makeDevice,
+ Core core)
: kCapabilities(std::move(capabilities)),
kSupportedExtensions(std::move(supportedExtensions)),
kType(type),
kVersionString(std::move(versionString)),
kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded),
kServiceName(std::move(serviceName)),
+ kMakeDevice(makeDevice),
mCore(std::move(core)) {}
std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
@@ -874,7 +884,7 @@ Return<T_Return> VersionedIDevice::recoverable(
if (pingReturn.isDeadObject()) {
VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering "
<< kServiceName;
- sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(kServiceName);
+ sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/false);
if (recoveredDevice == nullptr) {
VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for "
<< kServiceName;
@@ -911,7 +921,7 @@ int VersionedIDevice::wait() const {
auto pingReturn = mCore.getDevice<V1_0::IDevice>()->ping();
if (pingReturn.isDeadObject()) {
VLOG(DRIVER) << "VersionedIDevice::wait -- Recovering " << kServiceName;
- sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::getService(kServiceName);
+ sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/true);
if (recoveredDevice == nullptr) {
LOG(ERROR) << "VersionedIDevice::wait got a null IDevice for " << kServiceName;
return ANEURALNETWORKS_OP_FAILED;
diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h
index 94ca3fe49..efde0bdf5 100644
--- a/nn/runtime/VersionedInterfaces.h
+++ b/nn/runtime/VersionedInterfaces.h
@@ -72,12 +72,12 @@ class VersionedIDevice {
* protections.
*
* @param serviceName The name of the service that provides "device".
- * @param device A device object that is at least version 1.0 of the IDevice
- * interface.
+ * @param makeDevice A device factory function that returns a device object
+ * that is at least version 1.0 of the IDevice interface.
* @return A valid VersionedIDevice object, otherwise nullptr.
*/
static std::shared_ptr<VersionedIDevice> create(std::string serviceName,
- sp<hal::V1_0::IDevice> device);
+ const hal::DeviceFactory& makeDevice);
/**
* Constructor for the VersionedIDevice object.
@@ -92,6 +92,8 @@ class VersionedIDevice {
* @param numberOfCacheFilesNeeded Number of model cache and data cache
* files needed by the driver.
* @param serviceName The name of the service that provides core.getDevice<V1_0::IDevice>().
+ * @param makeDevice A device factory function that returns a device object
+ * that is at least version 1.0 of the IDevice interface.
* @param core An object that encapsulates a V1_0::IDevice, any appropriate downcasts to
* newer interfaces, and a hidl_death_recipient that will proactively handle
* the case when the service containing the IDevice object crashes.
@@ -100,7 +102,7 @@ class VersionedIDevice {
std::vector<hal::Extension> supportedExtensions, int32_t type,
std::string versionString,
std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
- std::string serviceName, Core core);
+ std::string serviceName, const hal::DeviceFactory& makeDevice, Core core);
/**
* Gets the capabilities of a driver.
@@ -554,6 +556,9 @@ class VersionedIDevice {
// The name of the service that implements the driver.
const std::string kServiceName;
+ // Factory function object to generate an IDevice object.
+ const hal::DeviceFactory kMakeDevice;
+
// Guards access to mCore.
mutable std::shared_mutex mMutex;