diff options
author | Michael Butler <butlermichael@google.com> | 2019-09-07 15:16:51 -0700 |
---|---|---|
committer | Michael Butler <butlermichael@google.com> | 2020-04-24 14:25:23 -0700 |
commit | 5591a18ea83233a6fd41366da16acdc3028cc7f1 (patch) | |
tree | cedd4adfd9edf717ebbbdb1c693c18fc159c494b /nn | |
parent | af55ce116ec438bf73c250ba8b9f64631f209608 (diff) | |
download | ml-5591a18ea83233a6fd41366da16acdc3028cc7f1.tar.gz |
Simplify IDevice reboot logic
HIDL allows a service to be retrieved with two functions:
* <Interface>::getService — blocks until service is retrieved
* <Interface>::tryGetService — immediately returns service or nullptr
Currently, the NN runtime retrieves the service in three different
places:
1) When the runtime first starts, <Interface>::getService is used to
acquire all services
2) When the object is dead, <Interface>::tryGetService is used to
attempt to reacquire the service, but will quickly resume if the
service is still rebooting
3) When the client calls ANNDevice_wait, <Interface>::getService is used
to block until the service is active again
This CL simplifies the IDevice reboot logic by changing these static
class functions to dependency injection. Specifically, VersionedIDevice
now retrieves a handle to the IDevice object through a DeviceFactory
function that is passed in when the VersionedIDevice object is created.
This function can either operate as blocking or nonblocking to support
all use-cases described above, and makes it easier to test the
VersionedIDevice recovery code.
Bug: 139189546
Test: mma
Test: NeuralNetworksTest_static
Test: CtsNNAPITestCases
Change-Id: I012eb4df6f09f98bdfbd0835457ba98bc22d906e
Diffstat (limited to 'nn')
-rw-r--r-- | nn/common/include/HalInterfaces.h | 1 | ||||
-rw-r--r-- | nn/runtime/Manager.cpp | 37 | ||||
-rw-r--r-- | nn/runtime/Manager.h | 5 | ||||
-rw-r--r-- | nn/runtime/Memory.cpp | 2 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.cpp | 22 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.h | 13 |
6 files changed, 49 insertions, 31 deletions
diff --git a/nn/common/include/HalInterfaces.h b/nn/common/include/HalInterfaces.h index fe1ff563e..4e3a3800b 100644 --- a/nn/common/include/HalInterfaces.h +++ b/nn/common/include/HalInterfaces.h @@ -103,6 +103,7 @@ using OperandExtraParams = V1_2::Operand::ExtraParams; using CacheToken = hardware::hidl_array<uint8_t, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)>; +using DeviceFactory = std::function<sp<V1_0::IDevice>(bool blocking)>; using ModelFactory = std::function<Model()>; inline constexpr Priority kDefaultPriority = Priority::MEDIUM; diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp index 7be8419ad..310710e3c 100644 --- a/nn/runtime/Manager.cpp +++ b/nn/runtime/Manager.cpp @@ -55,9 +55,10 @@ const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX // A Device with actual underlying driver class DriverDevice : public Device { public: - // Create a DriverDevice from a name and an IDevice. + // Create a DriverDevice from a name and a DeviceFactory function. // Returns nullptr on failure. - static std::shared_ptr<DriverDevice> create(std::string name, sp<V1_0::IDevice> device); + static std::shared_ptr<DriverDevice> create(const std::string& name, + const DeviceFactory& makeDevice); // Prefer using DriverDevice::create DriverDevice(std::shared_ptr<VersionedIDevice> device); @@ -159,6 +160,7 @@ class DriverPreparedModel : public PreparedModel { DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device) : kInterface(std::move(device)) { + CHECK(kInterface != nullptr); #ifdef NN_DEBUGGABLE static const char samplePrefix[] = "sample"; if (getName().substr(0, sizeof(samplePrefix) - 1) == samplePrefix) { @@ -167,17 +169,17 @@ DriverDevice::DriverDevice(std::shared_ptr<VersionedIDevice> device) #endif // NN_DEBUGGABLE } -std::shared_ptr<DriverDevice> DriverDevice::create(std::string name, sp<V1_0::IDevice> device) { - CHECK(device != nullptr); - std::shared_ptr<VersionedIDevice> versionedDevice = - VersionedIDevice::create(name, std::move(device)); - if (versionedDevice == nullptr) { +std::shared_ptr<DriverDevice> DriverDevice::create(const std::string& name, + const DeviceFactory& makeDevice) { + CHECK(makeDevice != nullptr); + std::shared_ptr<VersionedIDevice> device = VersionedIDevice::create(name, makeDevice); + if (device == nullptr) { LOG(ERROR) << "DriverDevice::create failed to create VersionedIDevice object for service " << name; return nullptr; } - return std::make_shared<DriverDevice>(std::move(versionedDevice)); + return std::make_shared<DriverDevice>(std::move(device)); } std::vector<bool> DriverDevice::getSupportedOperations(const MetaModel& metaModel) const { @@ -817,7 +819,8 @@ std::shared_ptr<Device> DeviceManager::getCpuDevice() { std::shared_ptr<Device> DeviceManager::forTest_makeDriverDevice(const std::string& name, const sp<V1_0::IDevice>& device) { - const auto driverDevice = DriverDevice::create(name, device); + const DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; }; + const auto driverDevice = DriverDevice::create(name, makeDevice); CHECK(driverDevice != nullptr); return driverDevice; } @@ -829,12 +832,10 @@ void DeviceManager::findAvailableDevices() { const auto names = hardware::getAllHalInstanceNames(V1_0::IDevice::descriptor); for (const auto& name : names) { VLOG(MANAGER) << "Found interface " << name; - sp<V1_0::IDevice> device = V1_0::IDevice::getService(name); - if (device == nullptr) { - LOG(ERROR) << "Got a null IDEVICE for " << name; - continue; - } - registerDevice(name, device); + const DeviceFactory makeDevice = [name](bool blocking) { + return blocking ? V1_0::IDevice::getService(name) : V1_0::IDevice::tryGetService(name); + }; + registerDevice(name, makeDevice); } // register CPU fallback device @@ -842,9 +843,9 @@ void DeviceManager::findAvailableDevices() { mDevicesCpuOnly.push_back(CpuDevice::get()); } -void DeviceManager::registerDevice(const std::string& name, const sp<V1_0::IDevice>& device) { - if (const auto d = DriverDevice::create(name, device)) { - mDevices.push_back(d); +void DeviceManager::registerDevice(const std::string& name, const DeviceFactory& makeDevice) { + if (auto device = DriverDevice::create(name, makeDevice)) { + mDevices.push_back(std::move(device)); } } diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h index c28ee49a6..d6d483576 100644 --- a/nn/runtime/Manager.h +++ b/nn/runtime/Manager.h @@ -169,7 +169,8 @@ class DeviceManager { // Register a test device. void forTest_registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device) { - registerDevice(name, device); + const hal::DeviceFactory makeDevice = [device](bool /*blocking*/) { return device; }; + registerDevice(name, makeDevice); } // Re-initialize the list of available devices. @@ -192,7 +193,7 @@ class DeviceManager { DeviceManager(); // Adds a device for the manager to use. - void registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device); + void registerDevice(const std::string& name, const hal::DeviceFactory& makeDevice); void findAvailableDevices(); diff --git a/nn/runtime/Memory.cpp b/nn/runtime/Memory.cpp index 7bfaf5562..e0bd6b953 100644 --- a/nn/runtime/Memory.cpp +++ b/nn/runtime/Memory.cpp @@ -194,7 +194,7 @@ Memory::Memory(sp<hal::IBuffer> buffer, uint32_t token) : kBuffer(std::move(buffer)), kToken(token) {} Memory::~Memory() { - for (const auto [ptr, weakBurst] : mUsedBy) { + for (const auto& [ptr, weakBurst] : mUsedBy) { if (const std::shared_ptr<ExecutionBurstController> burst = weakBurst.lock()) { burst->freeMemory(getKey()); } diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index cd39a52a7..3ae950eac 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -703,8 +703,16 @@ std::optional<InitialData> initialize(const Core& core) { } std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName, - sp<V1_0::IDevice> device) { - CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object."; + const DeviceFactory& makeDevice) { + CHECK(makeDevice != nullptr) + << "VersionedIDevice::create passed invalid device factory object."; + + // get handle to IDevice object + sp<V1_0::IDevice> device = makeDevice(/*blocking=*/true); + if (device == nullptr) { + VLOG(DRIVER) << "VersionedIDevice::create got a null IDevice for " << serviceName; + return nullptr; + } auto core = Core::create(std::move(device)); if (!core.has_value()) { @@ -722,20 +730,22 @@ std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceNa std::move(*initialData); return std::make_shared<VersionedIDevice>( std::move(capabilities), std::move(supportedExtensions), type, std::move(versionString), - numberOfCacheFilesNeeded, std::move(serviceName), std::move(core.value())); + numberOfCacheFilesNeeded, std::move(serviceName), makeDevice, std::move(core.value())); } VersionedIDevice::VersionedIDevice(hal::Capabilities capabilities, std::vector<hal::Extension> supportedExtensions, int32_t type, std::string versionString, std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded, - std::string serviceName, Core core) + std::string serviceName, const DeviceFactory& makeDevice, + Core core) : kCapabilities(std::move(capabilities)), kSupportedExtensions(std::move(supportedExtensions)), kType(type), kVersionString(std::move(versionString)), kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded), kServiceName(std::move(serviceName)), + kMakeDevice(makeDevice), mCore(std::move(core)) {} std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) { @@ -874,7 +884,7 @@ Return<T_Return> VersionedIDevice::recoverable( if (pingReturn.isDeadObject()) { VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering " << kServiceName; - sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(kServiceName); + sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/false); if (recoveredDevice == nullptr) { VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for " << kServiceName; @@ -911,7 +921,7 @@ int VersionedIDevice::wait() const { auto pingReturn = mCore.getDevice<V1_0::IDevice>()->ping(); if (pingReturn.isDeadObject()) { VLOG(DRIVER) << "VersionedIDevice::wait -- Recovering " << kServiceName; - sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::getService(kServiceName); + sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/true); if (recoveredDevice == nullptr) { LOG(ERROR) << "VersionedIDevice::wait got a null IDevice for " << kServiceName; return ANEURALNETWORKS_OP_FAILED; diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h index 94ca3fe49..efde0bdf5 100644 --- a/nn/runtime/VersionedInterfaces.h +++ b/nn/runtime/VersionedInterfaces.h @@ -72,12 +72,12 @@ class VersionedIDevice { * protections. * * @param serviceName The name of the service that provides "device". - * @param device A device object that is at least version 1.0 of the IDevice - * interface. + * @param makeDevice A device factory function that returns a device object + * that is at least version 1.0 of the IDevice interface. * @return A valid VersionedIDevice object, otherwise nullptr. */ static std::shared_ptr<VersionedIDevice> create(std::string serviceName, - sp<hal::V1_0::IDevice> device); + const hal::DeviceFactory& makeDevice); /** * Constructor for the VersionedIDevice object. @@ -92,6 +92,8 @@ class VersionedIDevice { * @param numberOfCacheFilesNeeded Number of model cache and data cache * files needed by the driver. * @param serviceName The name of the service that provides core.getDevice<V1_0::IDevice>(). + * @param makeDevice A device factory function that returns a device object + * that is at least version 1.0 of the IDevice interface. * @param core An object that encapsulates a V1_0::IDevice, any appropriate downcasts to * newer interfaces, and a hidl_death_recipient that will proactively handle * the case when the service containing the IDevice object crashes. @@ -100,7 +102,7 @@ class VersionedIDevice { std::vector<hal::Extension> supportedExtensions, int32_t type, std::string versionString, std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded, - std::string serviceName, Core core); + std::string serviceName, const hal::DeviceFactory& makeDevice, Core core); /** * Gets the capabilities of a driver. @@ -554,6 +556,9 @@ class VersionedIDevice { // The name of the service that implements the driver. const std::string kServiceName; + // Factory function object to generate an IDevice object. + const hal::DeviceFactory kMakeDevice; + // Guards access to mCore. mutable std::shared_mutex mMutex; |