summaryrefslogtreecommitdiff
path: root/nn
diff options
context:
space:
mode:
authorMichael Butler <butlermichael@google.com>2019-09-07 15:16:51 -0700
committerMichael Butler <butlermichael@google.com>2020-04-24 14:25:23 -0700
commit5591a18ea83233a6fd41366da16acdc3028cc7f1 (patch)
treecedd4adfd9edf717ebbbdb1c693c18fc159c494b /nn
parentaf55ce116ec438bf73c250ba8b9f64631f209608 (diff)
downloadml-5591a18ea83233a6fd41366da16acdc3028cc7f1.tar.gz
Simplify IDevice reboot logic
HIDL allows a service to be retrieved with two functions: * <Interface>::getService — blocks until service is retrieved * <Interface>::tryGetService — immediately returns service or nullptr Currently, the NN runtime retrieves the service in three different places: 1) When the runtime first starts, <Interface>::getService is used to acquire all services 2) When the object is dead, <Interface>::tryGetService is used to attempt to reacquire the service, but will quickly resume if the service is still rebooting 3) When the client calls ANNDevice_wait, <Interface>::getService is used to block until the service is active again This CL simplifies the IDevice reboot logic by changing these static class functions to dependency injection. Specifically, VersionedIDevice now retrieves a handle to the IDevice object through a DeviceFactory function that is passed in when the VersionedIDevice object is created. This function can either operate as blocking or nonblocking to support all use-cases described above, and makes it easier to test the VersionedIDevice recovery code. Bug: 139189546 Test: mma Test: NeuralNetworksTest_static Test: CtsNNAPITestCases Change-Id: I012eb4df6f09f98bdfbd0835457ba98bc22d906e
Diffstat (limited to 'nn')
-rw-r--r--nn/common/include/HalInterfaces.h1
-rw-r--r--nn/runtime/Manager.cpp37
-rw-r--r--nn/runtime/Manager.h5
-rw-r--r--nn/runtime/Memory.cpp2
-rw-r--r--nn/runtime/VersionedInterfaces.cpp22
-rw-r--r--nn/runtime/VersionedInterfaces.h13
6 files changed, 49 insertions, 31 deletions
diff --git a/nn/common/include/HalInterfaces.h b/nn/common/include/HalInterfaces.h
index fe1ff563e..4e3a3800b 100644
--- a/nn/common/include/HalInterfaces.h
+++ b/nn/common/include/HalInterfaces.h
@@ -103,6 +103,7 @@ using OperandExtraParams = V1_2::Operand::ExtraParams;
using CacheToken =
hardware::hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>;
+using DeviceFactory = std::function<sp<V1_0::IDevice>(bool blocking)>;
using ModelFactory = std::function<Model()>;
inline constexpr Priority kDefaultPriority = Priority::MEDIUM;
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index 7be8419ad..310710e3c 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -55,9 +55,10 @@ const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX
// A Device with actual underlying driver
class DriverDevice : public Device {
public:
- // Create a DriverDevice from a name and an IDevice.
+ // Create a DriverDevice from a name and a DeviceFactory function.
// Returns nullptr on failure.
- static std::shared_ptr<DriverDevice> create(std::string name, sp<V1_0::IDevice> device);
+ static std::shared_ptr<DriverDevice> create(const std::string& name,
+ const DeviceFactory& makeDevice);
// Prefer using DriverDevice::create
DriverDevice(std::shared_ptr<VersionedIDevice> device);
@@ -159,6 +160,7 @@ class DriverPreparedModel : public PreparedModel {
DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device)
: kInterface(std::move(device)) {
+ CHECK(kInterface != nullptr);
#ifdef NN_DEBUGGABLE
static const char samplePrefix[] = "sample";
if (getName().substr(0, sizeof(samplePrefix) - 1) == samplePrefix) {
@@ -167,17 +169,17 @@ DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device)
#endif // NN_DEBUGGABLE
}
-std::shared_ptr<DriverDevice> DriverDevice::create(std::string name, sp<V1_0::IDevice> device) {
- CHECK(device != nullptr);
- std::shared_ptr<VersionedIDevice> versionedDevice =
- VersionedIDevice::create(name, std::move(device));
- if (versionedDevice == nullptr) {
+std::shared_ptr<DriverDevice> DriverDevice::create(const std::string& name,
+ const DeviceFactory& makeDevice) {
+ CHECK(makeDevice != nullptr);
+ std::shared_ptr<VersionedIDevice> device = VersionedIDevice::create(name, makeDevice);
+ if (device == nullptr) {
LOG(ERROR) << "DriverDevice::create failed to create VersionedIDevice object for service "
<< name;
return nullptr;
}
- return std::make_shared<DriverDevice>(std::move(versionedDevice));
+ return std::make_shared<DriverDevice>(std::move(device));
}
std::vector<bool> DriverDevice::getSupportedOperations(const MetaModel& metaModel) const {
@@ -817,7 +819,8 @@ std::shared_ptr<Device> DeviceManager::getCpuDevice() {
std::shared_ptr<Device> DeviceManager::forTest_makeDriverDevice(const std::string& name,
const sp<V1_0::IDevice>& device) {
- const auto driverDevice = DriverDevice::create(name, device);
+ const DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; };
+ const auto driverDevice = DriverDevice::create(name, makeDevice);
CHECK(driverDevice != nullptr);
return driverDevice;
}
@@ -829,12 +832,10 @@ void DeviceManager::findAvailableDevices() {
const auto names = hardware::getAllHalInstanceNames(V1_0::IDevice::descriptor);
for (const auto& name : names) {
VLOG(MANAGER) << "Found interface " << name;
- sp<V1_0::IDevice> device = V1_0::IDevice::getService(name);
- if (device == nullptr) {
- LOG(ERROR) << "Got a null IDEVICE for " << name;
- continue;
- }
- registerDevice(name, device);
+ const DeviceFactory makeDevice = [name](bool blocking) {
+ return blocking ? V1_0::IDevice::getService(name) : V1_0::IDevice::tryGetService(name);
+ };
+ registerDevice(name, makeDevice);
}
// register CPU fallback device
@@ -842,9 +843,9 @@ void DeviceManager::findAvailableDevices() {
mDevicesCpuOnly.push_back(CpuDevice::get());
}
-void DeviceManager::registerDevice(const std::string& name, const sp<V1_0::IDevice>& device) {
- if (const auto d = DriverDevice::create(name, device)) {
- mDevices.push_back(d);
+void DeviceManager::registerDevice(const std::string& name, const DeviceFactory& makeDevice) {
+ if (auto device = DriverDevice::create(name, makeDevice)) {
+ mDevices.push_back(std::move(device));
}
}
diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h
index c28ee49a6..d6d483576 100644
--- a/nn/runtime/Manager.h
+++ b/nn/runtime/Manager.h
@@ -169,7 +169,8 @@ class DeviceManager {
// Register a test device.
void forTest_registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device) {
- registerDevice(name, device);
+ const hal::DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; };
+ registerDevice(name, makeDevice);
}
// Re-initialize the list of available devices.
@@ -192,7 +193,7 @@ class DeviceManager {
DeviceManager();
// Adds a device for the manager to use.
- void registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device);
+ void registerDevice(const std::string& name, const hal::DeviceFactory& makeDevice);
void findAvailableDevices();
diff --git a/nn/runtime/Memory.cpp b/nn/runtime/Memory.cpp
index 7bfaf5562..e0bd6b953 100644
--- a/nn/runtime/Memory.cpp
+++ b/nn/runtime/Memory.cpp
@@ -194,7 +194,7 @@ Memory::Memory(sp<hal::IBuffer> buffer, uint32_t token)
: kBuffer(std::move(buffer)), kToken(token) {}
Memory::~Memory() {
- for (const auto [ptr, weakBurst] : mUsedBy) {
+ for (const auto& [ptr, weakBurst] : mUsedBy) {
if (const std::shared_ptr<ExecutionBurstController> burst = weakBurst.lock()) {
burst->freeMemory(getKey());
}
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index cd39a52a7..3ae950eac 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -703,8 +703,16 @@ std::optional<InitialData> initialize(const Core& core) {
}
std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
- sp<V1_0::IDevice> device) {
- CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object.";
+ const DeviceFactory& makeDevice) {
+ CHECK(makeDevice != nullptr)
+ << "VersionedIDevice::create passed invalid device factory object.";
+
+ // get handle to IDevice object
+ sp<V1_0::IDevice> device = makeDevice(/*blocking=*/true);
+ if (device == nullptr) {
+ VLOG(DRIVER) << "VersionedIDevice::create got a null IDevice for " << serviceName;
+ return nullptr;
+ }
auto core = Core::create(std::move(device));
if (!core.has_value()) {
@@ -722,20 +730,22 @@ std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceNa
std::move(*initialData);
return std::make_shared<VersionedIDevice>(
std::move(capabilities), std::move(supportedExtensions), type, std::move(versionString),
- numberOfCacheFilesNeeded, std::move(serviceName), std::move(core.value()));
+ numberOfCacheFilesNeeded, std::move(serviceName), makeDevice, std::move(core.value()));
}
VersionedIDevice::VersionedIDevice(hal::Capabilities capabilities,
std::vector<hal::Extension> supportedExtensions, int32_t type,
std::string versionString,
std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
- std::string serviceName, Core core)
+ std::string serviceName, const DeviceFactory& makeDevice,
+ Core core)
: kCapabilities(std::move(capabilities)),
kSupportedExtensions(std::move(supportedExtensions)),
kType(type),
kVersionString(std::move(versionString)),
kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded),
kServiceName(std::move(serviceName)),
+ kMakeDevice(makeDevice),
mCore(std::move(core)) {}
std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
@@ -874,7 +884,7 @@ Return<T_Return> VersionedIDevice::recoverable(
if (pingReturn.isDeadObject()) {
VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering "
<< kServiceName;
- sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(kServiceName);
+ sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/false);
if (recoveredDevice == nullptr) {
VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for "
<< kServiceName;
@@ -911,7 +921,7 @@ int VersionedIDevice::wait() const {
auto pingReturn = mCore.getDevice<V1_0::IDevice>()->ping();
if (pingReturn.isDeadObject()) {
VLOG(DRIVER) << "VersionedIDevice::wait -- Recovering " << kServiceName;
- sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::getService(kServiceName);
+ sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/true);
if (recoveredDevice == nullptr) {
LOG(ERROR) << "VersionedIDevice::wait got a null IDevice for " << kServiceName;
return ANEURALNETWORKS_OP_FAILED;
diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h
index 94ca3fe49..efde0bdf5 100644
--- a/nn/runtime/VersionedInterfaces.h
+++ b/nn/runtime/VersionedInterfaces.h
@@ -72,12 +72,12 @@ class VersionedIDevice {
* protections.
*
* @param serviceName The name of the service that provides "device".
- * @param device A device object that is at least version 1.0 of the IDevice
- * interface.
+ * @param makeDevice A device factory function that returns a device object
+ * that is at least version 1.0 of the IDevice interface.
* @return A valid VersionedIDevice object, otherwise nullptr.
*/
static std::shared_ptr<VersionedIDevice> create(std::string serviceName,
- sp<hal::V1_0::IDevice> device);
+ const hal::DeviceFactory& makeDevice);
/**
* Constructor for the VersionedIDevice object.
@@ -92,6 +92,8 @@ class VersionedIDevice {
* @param numberOfCacheFilesNeeded Number of model cache and data cache
* files needed by the driver.
* @param serviceName The name of the service that provides core.getDevice<V1_0::IDevice>().
+ * @param makeDevice A device factory function that returns a device object
+ * that is at least version 1.0 of the IDevice interface.
* @param core An object that encapsulates a V1_0::IDevice, any appropriate downcasts to
* newer interfaces, and a hidl_death_recipient that will proactively handle
* the case when the service containing the IDevice object crashes.
@@ -100,7 +102,7 @@ class VersionedIDevice {
std::vector<hal::Extension> supportedExtensions, int32_t type,
std::string versionString,
std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
- std::string serviceName, Core core);
+ std::string serviceName, const hal::DeviceFactory& makeDevice, Core core);
/**
* Gets the capabilities of a driver.
@@ -554,6 +556,9 @@ class VersionedIDevice {
// The name of the service that implements the driver.
const std::string kServiceName;
+ // Factory function object to generate an IDevice object.
+ const hal::DeviceFactory kMakeDevice;
+
// Guards access to mCore.
mutable std::shared_mutex mMutex;