diff options
author | David Gross <dgross@google.com> | 2019-05-31 02:08:32 +0000 |
---|---|---|
committer | Android (Google) Code Review <android-gerrit@google.com> | 2019-05-31 02:08:32 +0000 |
commit | 6109869f3050bcb5d03f62102de8b02f961238d2 (patch) | |
tree | 3de1b25f8b6b1357f792e0c11aa7cd4c71d3c2a1 | |
parent | 2658cca4a66ade24f9d3efdab368266f784dd06d (diff) | |
parent | 00fa8d35cbd38d5ea7cab48dc96dcb9851578aab (diff) | |
download | ml-6109869f3050bcb5d03f62102de8b02f961238d2.tar.gz |
Merge "Partially recover from a driver crash" into qt-dev
-rw-r--r-- | nn/runtime/Manager.cpp | 2 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.cpp | 363 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.h | 233 |
3 files changed, 459 insertions, 139 deletions
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp index c5d7b1fad..b479d3964 100644 --- a/nn/runtime/Manager.cpp +++ b/nn/runtime/Manager.cpp @@ -96,7 +96,7 @@ class DriverDevice : public Device { }; DriverDevice::DriverDevice(std::string name, const sp<V1_0::IDevice>& device) - : mName(std::move(name)), mInterface(VersionedIDevice::create(device)) {} + : mName(std::move(name)), mInterface(VersionedIDevice::create(mName, device)) {} // TODO: handle errors from initialize correctly bool DriverDevice::initialize() { diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index d0439b213..0a359582d 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -25,6 +25,7 @@ #include <android-base/scopeguard.h> #include <android-base/thread_annotations.h> #include <functional> +#include <type_traits> namespace android { namespace nn { @@ -40,6 +41,10 @@ void sendFailureMessage(const sp<IPreparedModelCallback>& cb) { cb->notify(ErrorStatus::GENERAL_FAILURE, nullptr); } +void sendFailureMessage(const sp<PreparedModelCallback>& cb) { + sendFailureMessage(static_cast<sp<IPreparedModelCallback>>(cb)); +} + void sendFailureMessage(const sp<IExecutionCallback>& cb) { cb->notify(ErrorStatus::GENERAL_FAILURE); } @@ -219,18 +224,33 @@ bool VersionedIPreparedModel::operator!=(nullptr_t) const { return mPreparedModelV1_0 != nullptr; } -std::shared_ptr<VersionedIDevice> VersionedIDevice::create(sp<V1_0::IDevice> device) { +std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName, + sp<V1_0::IDevice> device) { + auto core = Core::create(std::move(device)); + if (!core.has_value()) { + LOG(ERROR) << "VersionedIDevice::create -- Failed to create Core."; + return nullptr; + } + + // return a valid VersionedIDevice object + return std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value())); +} + +VersionedIDevice::VersionedIDevice(std::string serviceName, Core core) + 
: mServiceName(std::move(serviceName)), mCore(std::move(core)) {} + +std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) { // verify input if (!device) { - LOG(ERROR) << "VersionedIDevice::create -- passed invalid device object."; - return nullptr; + LOG(ERROR) << "VersionedIDevice::Core::create -- passed invalid device object."; + return {}; } // create death handler object sp<IDeviceDeathHandler> deathHandler = new (std::nothrow) IDeviceDeathHandler(); if (!deathHandler) { - LOG(ERROR) << "VersionedIDevice::create -- Failed to create IDeviceDeathHandler."; - return nullptr; + LOG(ERROR) << "VersionedIDevice::Core::create -- Failed to create IDeviceDeathHandler."; + return {}; } // linkToDeath registers a callback that will be invoked on service death to @@ -239,60 +259,192 @@ std::shared_ptr<VersionedIDevice> VersionedIDevice::create(sp<V1_0::IDevice> dev // providing the response. const Return<bool> ret = device->linkToDeath(deathHandler, 0); if (!ret.isOk() || ret != true) { - LOG(ERROR) << "VersionedIDevice::create -- Failed to register a death recipient for the " - "IDevice object."; - return nullptr; + LOG(ERROR) + << "VersionedIDevice::Core::create -- Failed to register a death recipient for the " + "IDevice object."; + return {}; } - // return a valid VersionedIDevice object - return std::make_shared<VersionedIDevice>(std::move(device), std::move(deathHandler)); + // return a valid Core object + return Core(std::move(device), std::move(deathHandler)); } // HIDL guarantees all V1_1 interfaces inherit from their corresponding V1_0 interfaces. 
-VersionedIDevice::VersionedIDevice(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler) +VersionedIDevice::Core::Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler) : mDeviceV1_0(std::move(device)), mDeviceV1_1(V1_1::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)), mDeviceV1_2(V1_2::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)), mDeathHandler(std::move(deathHandler)) {} -VersionedIDevice::~VersionedIDevice() { - // It is safe to ignore any errors resulting from this unlinkToDeath call - // because the VersionedIDevice object is already being destroyed and its - // underlying IDevice object is no longer being used by the NN runtime. - mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk(); +VersionedIDevice::Core::~Core() { + if (mDeathHandler != nullptr) { + CHECK(mDeviceV1_0 != nullptr); + // It is safe to ignore any errors resulting from this unlinkToDeath call + // because the VersionedIDevice::Core object is already being destroyed and + // its underlying IDevice object is no longer being used by the NN runtime. 
+ mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk(); + } +} + +VersionedIDevice::Core::Core(Core&& other) noexcept + : mDeviceV1_0(std::move(other.mDeviceV1_0)), + mDeviceV1_1(std::move(other.mDeviceV1_1)), + mDeviceV1_2(std::move(other.mDeviceV1_2)), + mDeathHandler(std::move(other.mDeathHandler)) { + other.mDeathHandler = nullptr; +} + +VersionedIDevice::Core& VersionedIDevice::Core::operator=(Core&& other) noexcept { + if (this != &other) { + mDeviceV1_0 = std::move(other.mDeviceV1_0); + mDeviceV1_1 = std::move(other.mDeviceV1_1); + mDeviceV1_2 = std::move(other.mDeviceV1_2); + mDeathHandler = std::move(other.mDeathHandler); + other.mDeathHandler = nullptr; + } + return *this; +} + +template <typename T_IDevice> +std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> VersionedIDevice::Core::getDeviceAndDeathHandler() + const { + return {getDevice<T_IDevice>(), mDeathHandler}; +} + +template <typename T_IDevice, typename T_Callback> +Return<ErrorStatus> callProtected( + const char* context, const std::function<Return<ErrorStatus>(const sp<T_IDevice>&)>& fn, + const sp<T_IDevice>& device, const sp<T_Callback>& callback, + const sp<IDeviceDeathHandler>& deathHandler) { + const auto scoped = deathHandler->protectCallback(callback); + Return<ErrorStatus> ret = fn(device); + // Suppose there was a transport error. We have the following cases: + // 1. Either not due to a dead device, or due to a device that was + // already dead at the time of the call to protectCallback(). In + // this case, the callback was never signalled. + // 2. Due to a device that died after the call to protectCallback() but + // before fn() completed. In this case, the callback was (or will + // be) signalled by the deathHandler. + // Furthermore, what if there was no transport error, but the ErrorStatus is + // other than NONE? We'll conservatively signal the callback anyway, just in + // case the driver was sloppy and failed to do so. 
+ if (!ret.isOk() || ret != ErrorStatus::NONE) { + // What if the deathHandler has signalled or will signal the callback? + // This is fine -- we're permitted to signal multiple times; and we're + // sending the same signal that the deathHandler does. + // + // What if the driver signalled the callback? Then this signal is + // ignored. + + if (ret.isOk()) { + LOG(ERROR) << context << " returned " << toString(static_cast<ErrorStatus>(ret)); + } else { + LOG(ERROR) << context << " failure: " << ret.description(); + } + sendFailureMessage(callback); + } + callback->wait(); + return ret; +} +template <typename T_Return, typename T_IDevice> +Return<T_Return> callProtected(const char*, + const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn, + const sp<T_IDevice>& device, const std::nullptr_t&, + const sp<IDeviceDeathHandler>&) { + return fn(device); +} + +template <typename T_Return, typename T_IDevice, typename T_Callback> +Return<T_Return> VersionedIDevice::recoverable( + const char* context, const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn, + const T_Callback& callback) const EXCLUDES(mMutex) { + CHECK_EQ(callback == nullptr, (std::is_same_v<T_Callback, std::nullptr_t>)); + + sp<T_IDevice> device; + sp<IDeviceDeathHandler> deathHandler; + std::tie(device, deathHandler) = getDeviceAndDeathHandler<T_IDevice>(); + + Return<T_Return> ret = callProtected(context, fn, device, callback, deathHandler); + + if (ret.isDeadObject()) { + { + std::unique_lock lock(mMutex); + // It's possible that another device has already done the recovery. + // It's harmless but wasteful for us to do so in this case. 
+ auto pingReturn = mCore.getDevice<T_IDevice>()->ping(); + if (pingReturn.isDeadObject()) { + VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering " + << mServiceName; + sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(mServiceName); + if (recoveredDevice == nullptr) { + VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for " + << mServiceName; + return ret; + } + + auto core = Core::create(std::move(recoveredDevice)); + if (!core.has_value()) { + LOG(ERROR) << "VersionedIDevice::recoverable -- Failed to create Core."; + return ret; + } + + mCore = std::move(core.value()); + } else { + VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context + << ") -- Someone else recovered " << mServiceName; + // Might still have a transport error, which we need to check + // before pingReturn goes out of scope. + (void)pingReturn.isOk(); + } + std::tie(device, deathHandler) = mCore.getDeviceAndDeathHandler<T_IDevice>(); + } + ret = callProtected(context, fn, device, callback, deathHandler); + // It's possible that the device died again, but we're only going to + // attempt recovery once per call to recoverable(). 
+ } + return ret; } std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() { const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}}; std::pair<ErrorStatus, Capabilities> result; - if (mDeviceV1_2 != nullptr) { + if (getDevice<V1_2::IDevice>() != nullptr) { NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_2"); - Return<void> ret = mDeviceV1_2->getCapabilities_1_2( - [&result](ErrorStatus error, const Capabilities& capabilities) { - result = std::make_pair(error, capabilities); + Return<void> ret = recoverable<void, V1_2::IDevice>( + __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) { + return device->getCapabilities_1_2( + [&result](ErrorStatus error, const Capabilities& capabilities) { + result = std::make_pair(error, capabilities); + }); }); if (!ret.isOk()) { LOG(ERROR) << "getCapabilities_1_2 failure: " << ret.description(); return {ErrorStatus::GENERAL_FAILURE, {}}; } - } else if (mDeviceV1_1 != nullptr) { + } else if (getDevice<V1_1::IDevice>() != nullptr) { NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_1"); - Return<void> ret = mDeviceV1_1->getCapabilities_1_1( - [&result](ErrorStatus error, const V1_1::Capabilities& capabilities) { - // Time taken to convert capabilities is trivial - result = std::make_pair(error, convertToV1_2(capabilities)); + Return<void> ret = recoverable<void, V1_1::IDevice>( + __FUNCTION__, [&result](const sp<V1_1::IDevice>& device) { + return device->getCapabilities_1_1( + [&result](ErrorStatus error, const V1_1::Capabilities& capabilities) { + // Time taken to convert capabilities is trivial + result = std::make_pair(error, convertToV1_2(capabilities)); + }); }); if (!ret.isOk()) { LOG(ERROR) << "getCapabilities_1_1 failure: " << ret.description(); return kFailure; } - } else if (mDeviceV1_0 != nullptr) { + } else if (getDevice<V1_0::IDevice>() != nullptr) { NNTRACE_FULL(NNTRACE_LAYER_IPC, 
NNTRACE_PHASE_INITIALIZATION, "getCapabilities"); - Return<void> ret = mDeviceV1_0->getCapabilities( - [&result](ErrorStatus error, const V1_0::Capabilities& capabilities) { - // Time taken to convert capabilities is trivial - result = std::make_pair(error, convertToV1_2(capabilities)); + Return<void> ret = recoverable<void, V1_0::IDevice>( + __FUNCTION__, [&result](const sp<V1_0::IDevice>& device) { + return device->getCapabilities( + [&result](ErrorStatus error, const V1_0::Capabilities& capabilities) { + // Time taken to convert capabilities is trivial + result = std::make_pair(error, convertToV1_2(capabilities)); + }); }); if (!ret.isOk()) { LOG(ERROR) << "getCapabilities failure: " << ret.description(); @@ -309,18 +461,21 @@ std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() { std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() { const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}}; NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedExtensions"); - if (mDeviceV1_2 != nullptr) { + if (getDevice<V1_2::IDevice>() != nullptr) { std::pair<ErrorStatus, hidl_vec<Extension>> result; - Return<void> ret = mDeviceV1_2->getSupportedExtensions( - [&result](ErrorStatus error, const hidl_vec<Extension>& extensions) { - result = std::make_pair(error, extensions); + Return<void> ret = recoverable<void, V1_2::IDevice>( + __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) { + return device->getSupportedExtensions( + [&result](ErrorStatus error, const hidl_vec<Extension>& extensions) { + result = std::make_pair(error, extensions); + }); }); if (!ret.isOk()) { LOG(ERROR) << "getSupportedExtensions failure: " << ret.description(); return kFailure; } return result; - } else if (mDeviceV1_0 != nullptr) { + } else if (getDevice<V1_0::IDevice>() != nullptr) { return {ErrorStatus::NONE, {/* No extensions. 
*/}}; } else { LOG(ERROR) << "Device not available!"; @@ -354,11 +509,14 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations( return std::make_pair(status, std::move(remappedSupported)); }; - if (mDeviceV1_2 != nullptr) { + if (getDevice<V1_2::IDevice>() != nullptr) { NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_2"); - Return<void> ret = mDeviceV1_2->getSupportedOperations_1_2( - model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) { - result = std::make_pair(error, supported); + Return<void> ret = recoverable<void, V1_2::IDevice>( + __FUNCTION__, [&model, &result](const sp<V1_2::IDevice>& device) { + return device->getSupportedOperations_1_2( + model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) { + result = std::make_pair(error, supported); + }); }); if (!ret.isOk()) { LOG(ERROR) << "getSupportedOperations_1_2 failure: " << ret.description(); @@ -367,7 +525,7 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations( return result; } - if (mDeviceV1_1 != nullptr) { + if (getDevice<V1_1::IDevice>() != nullptr) { const bool compliant = compliantWithV1_1(model); if (compliant || slicer) { V1_1::Model model11; @@ -383,9 +541,13 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations( } NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_1"); - Return<void> ret = mDeviceV1_1->getSupportedOperations_1_1( - model11, [&result](ErrorStatus error, const hidl_vec<bool>& supported) { - result = std::make_pair(error, supported); + Return<void> ret = recoverable<void, V1_1::IDevice>( + __FUNCTION__, [&model11, &result](const sp<V1_1::IDevice>& device) { + return device->getSupportedOperations_1_1( + model11, + [&result](ErrorStatus error, const hidl_vec<bool>& supported) { + result = std::make_pair(error, supported); + }); }); if (!ret.isOk()) { LOG(ERROR) << "getSupportedOperations_1_1 
failure: " << ret.description(); @@ -398,7 +560,7 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations( return result; } - if (mDeviceV1_0 != nullptr) { + if (getDevice<V1_0::IDevice>() != nullptr) { const bool compliant = compliantWithV1_0(model); if (compliant || slicer) { V1_0::Model model10; @@ -413,9 +575,13 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations( std::tie(model10, submodelOperationIndexToModelOperationIndex) = *slice10; } NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations"); - Return<void> ret = mDeviceV1_0->getSupportedOperations( - model10, [&result](ErrorStatus error, const hidl_vec<bool>& supported) { - result = std::make_pair(error, supported); + Return<void> ret = recoverable<void, V1_0::IDevice>( + __FUNCTION__, [&model10, &result](const sp<V1_0::IDevice>& device) { + return device->getSupportedOperations( + model10, + [&result](ErrorStatus error, const hidl_vec<bool>& supported) { + result = std::make_pair(error, supported); + }); }); if (!ret.isOk()) { LOG(ERROR) << "getSupportedOperations failure: " << ret.description(); @@ -443,12 +609,16 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic return kFailure; } - const auto scoped = mDeathHandler->protectCallback(callback); - // If 1.2 device, try preparing model - if (mDeviceV1_2 != nullptr) { - const Return<ErrorStatus> ret = mDeviceV1_2->prepareModel_1_2(model, preference, modelCache, - dataCache, token, callback); + if (getDevice<V1_2::IDevice>() != nullptr) { + const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>( + __FUNCTION__, + [&model, &preference, &modelCache, &dataCache, &token, + &callback](const sp<V1_2::IDevice>& device) { + return device->prepareModel_1_2(model, preference, modelCache, dataCache, token, + callback); + }, + callback); if (!ret.isOk()) { LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description(); return 
kFailure; @@ -462,7 +632,7 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic } // If 1.1 device, try preparing model (requires conversion) - if (mDeviceV1_1 != nullptr) { + if (getDevice<V1_1::IDevice>() != nullptr) { bool compliant = false; V1_1::Model model11; { @@ -477,8 +647,12 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic } } if (compliant) { - const Return<ErrorStatus> ret = - mDeviceV1_1->prepareModel_1_1(model11, preference, callback); + const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_1::IDevice>( + __FUNCTION__, + [&model11, &preference, &callback](const sp<V1_1::IDevice>& device) { + return device->prepareModel_1_1(model11, preference, callback); + }, + callback); if (!ret.isOk()) { LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description(); return kFailure; @@ -493,14 +667,12 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic makeVersionedIPreparedModel(callback->getPreparedModel())}; } - // TODO: partition the model such that v1.2 ops are not passed to v1.1 - // device LOG(ERROR) << "Could not handle prepareModel_1_1!"; return kFailure; } // If 1.0 device, try preparing model (requires conversion) - if (mDeviceV1_0 != nullptr) { + if (getDevice<V1_0::IDevice>() != nullptr) { bool compliant = false; V1_0::Model model10; { @@ -515,7 +687,12 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic } } if (compliant) { - const Return<ErrorStatus> ret = mDeviceV1_0->prepareModel(model10, callback); + const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_0::IDevice>( + __FUNCTION__, + [&model10, &callback](const sp<V1_0::IDevice>& device) { + return device->prepareModel(model10, callback); + }, + callback); if (!ret.isOk()) { LOG(ERROR) << "prepareModel failure: " << ret.description(); return kFailure; @@ -529,8 +706,6 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic 
makeVersionedIPreparedModel(callback->getPreparedModel())}; } - // TODO: partition the model such that v1.1 ops are not passed to v1.0 - // device LOG(ERROR) << "Could not handle prepareModel!"; return kFailure; } @@ -553,11 +728,13 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache, return kFailure; } - const auto scoped = mDeathHandler->protectCallback(callback); - - if (mDeviceV1_2 != nullptr) { - const Return<ErrorStatus> ret = - mDeviceV1_2->prepareModelFromCache(modelCache, dataCache, token, callback); + if (getDevice<V1_2::IDevice>() != nullptr) { + const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>( + __FUNCTION__, + [&modelCache, &dataCache, &token, &callback](const sp<V1_2::IDevice>& device) { + return device->prepareModelFromCache(modelCache, dataCache, token, callback); + }, + callback); if (!ret.isOk()) { LOG(ERROR) << "prepareModelFromCache failure: " << ret.description(); return kFailure; @@ -571,7 +748,7 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache, return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())}; } - if (mDeviceV1_1 != nullptr || mDeviceV1_0 != nullptr) { + if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) { LOG(ERROR) << "prepareModelFromCache called on V1_1 or V1_0 device"; return kFailure; } @@ -581,12 +758,13 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache, } DeviceStatus VersionedIDevice::getStatus() { - if (mDeviceV1_0 == nullptr) { + if (getDevice<V1_0::IDevice>() == nullptr) { LOG(ERROR) << "Device not available!"; return DeviceStatus::UNKNOWN; } - Return<DeviceStatus> ret = mDeviceV1_0->getStatus(); + Return<DeviceStatus> ret = recoverable<DeviceStatus, V1_0::IDevice>( + __FUNCTION__, [](const sp<V1_0::IDevice>& device) { return device->getStatus(); }); if (!ret.isOk()) { LOG(ERROR) << "getStatus failure: " << ret.description(); @@ -598,11 
+776,11 @@ DeviceStatus VersionedIDevice::getStatus() { int64_t VersionedIDevice::getFeatureLevel() { constexpr int64_t kFailure = -1; - if (mDeviceV1_2 != nullptr) { + if (getDevice<V1_2::IDevice>() != nullptr) { return __ANDROID_API_Q__; - } else if (mDeviceV1_1 != nullptr) { + } else if (getDevice<V1_1::IDevice>() != nullptr) { return __ANDROID_API_P__; - } else if (mDeviceV1_0 != nullptr) { + } else if (getDevice<V1_0::IDevice>() != nullptr) { return __ANDROID_API_O_MR1__; } else { LOG(ERROR) << "Device not available!"; @@ -614,10 +792,12 @@ int32_t VersionedIDevice::getType() const { constexpr int32_t kFailure = -1; std::pair<ErrorStatus, DeviceType> result; - if (mDeviceV1_2 != nullptr) { - Return<void> ret = - mDeviceV1_2->getType([&result](ErrorStatus error, DeviceType deviceType) { - result = std::make_pair(error, deviceType); + if (getDevice<V1_2::IDevice>() != nullptr) { + Return<void> ret = recoverable<void, V1_2::IDevice>( + __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) { + return device->getType([&result](ErrorStatus error, DeviceType deviceType) { + result = std::make_pair(error, deviceType); + }); }); if (!ret.isOk()) { LOG(ERROR) << "getType failure: " << ret.description(); @@ -634,17 +814,20 @@ std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() { const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""}; std::pair<ErrorStatus, hidl_string> result; - if (mDeviceV1_2 != nullptr) { - Return<void> ret = mDeviceV1_2->getVersionString( - [&result](ErrorStatus error, const hidl_string& version) { - result = std::make_pair(error, version); + if (getDevice<V1_2::IDevice>() != nullptr) { + Return<void> ret = recoverable<void, V1_2::IDevice>( + __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) { + return device->getVersionString( + [&result](ErrorStatus error, const hidl_string& version) { + result = std::make_pair(error, version); + }); }); if (!ret.isOk()) { LOG(ERROR) << "getVersion 
failure: " << ret.description(); return kFailure; } return result; - } else if (mDeviceV1_1 != nullptr || mDeviceV1_0 != nullptr) { + } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) { return {ErrorStatus::NONE, "UNKNOWN"}; } else { LOG(ERROR) << "Could not handle getVersionString"; @@ -657,17 +840,21 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi 0, 0}; std::tuple<ErrorStatus, uint32_t, uint32_t> result; - if (mDeviceV1_2 != nullptr) { - Return<void> ret = mDeviceV1_2->getNumberOfCacheFilesNeeded( - [&result](ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) { - result = {error, numModelCache, numDataCache}; + if (getDevice<V1_2::IDevice>() != nullptr) { + Return<void> ret = recoverable<void, V1_2::IDevice>( + __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) { + return device->getNumberOfCacheFilesNeeded([&result](ErrorStatus error, + uint32_t numModelCache, + uint32_t numDataCache) { + result = {error, numModelCache, numDataCache}; + }); }); if (!ret.isOk()) { LOG(ERROR) << "getNumberOfCacheFilesNeeded failure: " << ret.description(); return kFailure; } return result; - } else if (mDeviceV1_1 != nullptr || mDeviceV1_0 != nullptr) { + } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) { return {ErrorStatus::NONE, 0, 0}; } else { LOG(ERROR) << "Could not handle getNumberOfCacheFilesNeeded"; @@ -676,11 +863,11 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi } bool VersionedIDevice::operator==(nullptr_t) const { - return mDeviceV1_0 == nullptr; + return getDevice<V1_0::IDevice>() == nullptr; } bool VersionedIDevice::operator!=(nullptr_t) const { - return mDeviceV1_0 != nullptr; + return getDevice<V1_0::IDevice>() != nullptr; } } // namespace nn diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h index 81a66013f..f0aaf6d6e 100644 --- 
a/nn/runtime/VersionedInterfaces.h +++ b/nn/runtime/VersionedInterfaces.h @@ -20,9 +20,14 @@ #include "HalInterfaces.h" #include <android-base/macros.h> +#include <cstddef> +#include <functional> #include <memory> +#include <optional> +#include <shared_mutex> #include <string> #include <tuple> +#include <utility> #include "Callbacks.h" namespace android { @@ -53,6 +58,9 @@ class VersionedIPreparedModel; class VersionedIDevice { DISALLOW_IMPLICIT_CONSTRUCTORS(VersionedIDevice); + // forward declaration of nested class + class Core; + public: /** * Create a VersionedIDevice object. @@ -60,40 +68,26 @@ class VersionedIDevice { * Prefer using this function over the constructor, as it adds more * protections. * - * This call linksToDeath a hidl_death_recipient that can - * proactively handle the case when the service containing the IDevice - * object crashes. - * + * @param serviceName The name of the service that provides "device". * @param device A device object that is at least version 1.0 of the IDevice * interface. * @return A valid VersionedIDevice object, otherwise nullptr. */ - static std::shared_ptr<VersionedIDevice> create(sp<V1_0::IDevice> device); + static std::shared_ptr<VersionedIDevice> create(std::string serviceName, + sp<V1_0::IDevice> device); /** * Constructor for the VersionedIDevice object. * - * VersionedIDevice is constructed with the V1_0::IDevice object, which - * represents a device that is at least v1.0 of the interface. The - * constructor downcasts to the latest version of the IDevice interface, and - * will default to using the latest version of all IDevice interface - * methods automatically. - * - * @param device A device object that is at least version 1.0 of the IDevice - * interface. - * @param deathHandler A hidl_death_recipient that will proactively handle - * the case when the service containing the IDevice - * object crashes. 
- */ - VersionedIDevice(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler); - - /** - * Destructor for the VersionedIDevice object. + * VersionedIDevice will default to using the latest version of all IDevice + * interface methods automatically. * - * This destructor unlinksToDeath this object's hidl_death_recipient as it - * no longer needs to handle the case where the IDevice's service crashes. + * @param serviceName The name of the service that provides core.getDevice<V1_0::IDevice>(). + * @param core An object that encapsulates a V1_0::IDevice, any appropriate downcasts to + * newer interfaces, and a hidl_death_recipient that will proactively handle + * the case when the service containing the IDevice object crashes. */ - ~VersionedIDevice(); + VersionedIDevice(std::string serviceName, Core core); /** * Gets the capabilities of a driver. @@ -426,6 +420,13 @@ class VersionedIDevice { std::tuple<ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeeded(); /** + * Returns the name of the service that implements the driver + * + * @return serviceName The name of the service. + */ + std::string getServiceName() const { return mServiceName; } + + /** * Returns whether this handle to an IDevice object is valid or not. * * @return bool true if V1_0::IDevice (which could be V1_1::IDevice) is @@ -443,33 +444,165 @@ class VersionedIDevice { private: /** - * All versions of IDevice are necessary because the driver could be v1.0, - * v1.1, or a later version. All these pointers logically represent the same - * object. - * - * The general strategy is: HIDL returns a V1_0 device object, which - * (if not nullptr) could be v1.0, v1.1, or a greater version. The V1_0 - * object is then "dynamically cast" to a V1_1 object. If successful, - * mDeviceV1_1 will point to the same object as mDeviceV1_0; otherwise, - * mDeviceV1_1 will be nullptr. 
- * - * In general: - * * If the device is truly v1.0, mDeviceV1_0 will point to a valid object - * and mDeviceV1_1 will be nullptr. - * * If the device is truly v1.1 or later, both mDeviceV1_0 and mDeviceV1_1 - * will point to the same valid object. - * - * Idiomatic usage: if mDeviceV1_1 is non-null, do V1_1 dispatch; otherwise, - * do V1_0 dispatch. - */ - sp<V1_0::IDevice> mDeviceV1_0; - sp<V1_1::IDevice> mDeviceV1_1; - sp<V1_2::IDevice> mDeviceV1_2; - - /** - * HIDL callback to be invoked if the service for mDeviceV1_0 crashes. + * This is a utility class for VersionedIDevice that encapsulates a + * V1_0::IDevice, any appropriate downcasts to newer interfaces, and a + * hidl_death_recipient that will proactively handle the case when the + * service containing the IDevice object crashes. + * + * This is a convenience class to help VersionedIDevice recover from an + * IDevice object crash: It bundles together all the data that needs to + * change when recovering from a crash, and simplifies the process of + * instantiating that data (at VersionedIDevice creation time) and + * re-instantiating that data (at crash recovery time). */ - const sp<IDeviceDeathHandler> mDeathHandler; + class Core { + public: + /** + * Constructor for the Core object. + * + * Core is constructed with a V1_0::IDevice object, which represents a + * device that is at least v1.0 of the interface. The constructor + * downcasts to the latest version of the IDevice interface, allowing + * VersionedIDevice to default to using the latest version of all + * IDevice interface methods automatically. + * + * @param device A device object that is at least version 1.0 of the IDevice + * interface. + * @param deathHandler A hidl_death_recipient that will proactively handle + * the case when the service containing the IDevice + * object crashes. + */ + Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler); + + /** + * Destructor for the Core object. 
+ * + * This destructor unlinksToDeath this object's hidl_death_recipient as it + * no longer needs to handle the case where the IDevice's service crashes. + */ + ~Core(); + + // Support move but not copy + Core(Core&&) noexcept; + Core& operator=(Core&&) noexcept; + Core(const Core&) = delete; + Core& operator=(const Core&) = delete; + + /** + * Create a Core object. + * + * Prefer using this function over the constructor, as it adds more + * protections. + * + * This call linksToDeath a hidl_death_recipient that can + * proactively handle the case when the service containing the IDevice + * object crashes. + * + * @param device A device object that is at least version 1.0 of the IDevice + * interface. + * @return A valid Core object, otherwise nullopt. + */ + static std::optional<Core> create(sp<V1_0::IDevice> device); + + /** + * Returns sp<*::IDevice> that is a downcast of the sp<V1_0::IDevice> + * passed to the constructor. This will be nullptr if that IDevice is + * not actually of the specified downcast type. + */ + template <typename T_IDevice> + sp<T_IDevice> getDevice() const; + template <> + sp<V1_0::IDevice> getDevice() const { + return mDeviceV1_0; + } + template <> + sp<V1_1::IDevice> getDevice() const { + return mDeviceV1_1; + } + template <> + sp<V1_2::IDevice> getDevice() const { + return mDeviceV1_2; + } + + /** + * Returns sp<*::IDevice> (as per getDevice()) and the + * hidl_death_recipient that will proactively handle the case when the + * service containing the IDevice object crashes. + */ + template <typename T_IDevice> + std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> getDeviceAndDeathHandler() const; + + private: + /** + * All versions of IDevice are necessary because the driver could be v1.0, + * v1.1, or a later version. All these pointers logically represent the same + * object. + * + * The general strategy is: HIDL returns a V1_0 device object, which + * (if not nullptr) could be v1.0, v1.1, or a greater version. 
The V1_0 + * object is then "dynamically cast" to a V1_1 object. If successful, + * mDeviceV1_1 will point to the same object as mDeviceV1_0; otherwise, + * mDeviceV1_1 will be nullptr. + * + * In general: + * * If the device is truly v1.0, mDeviceV1_0 will point to a valid object + * and mDeviceV1_1 will be nullptr. + * * If the device is truly v1.1 or later, both mDeviceV1_0 and mDeviceV1_1 + * will point to the same valid object. + * + * Idiomatic usage: if mDeviceV1_1 is non-null, do V1_1 dispatch; otherwise, + * do V1_0 dispatch. + */ + sp<V1_0::IDevice> mDeviceV1_0; + sp<V1_1::IDevice> mDeviceV1_1; + sp<V1_2::IDevice> mDeviceV1_2; + + /** + * HIDL callback to be invoked if the service for mDeviceV1_0 crashes. + * + * nullptr if this Core instance is a move victim and hence has no + * callback to be unlinked. + */ + sp<IDeviceDeathHandler> mDeathHandler; + }; + + // This method retrieves the appropriate mCore.mDevice* field, under a read lock. + template <typename T_IDevice> + sp<T_IDevice> getDevice() const EXCLUDES(mMutex) { + std::shared_lock lock(mMutex); + return mCore.getDevice<T_IDevice>(); + } + + // This method retrieves the appropriate mCore.mDevice* fields, under a read lock. + template <typename T_IDevice> + auto getDeviceAndDeathHandler() const EXCLUDES(mMutex) { + std::shared_lock lock(mMutex); + return mCore.getDeviceAndDeathHandler<T_IDevice>(); + } + + // This method calls the function fn in a manner that supports recovering + // from a driver crash: If the driver implementation is dead because the + // driver crashed either before the call to fn or during the call to fn, we + // will attempt to obtain a new instance of the same driver and call fn + // again. + // + // If a callback is provided, this method protects it against driver death + // and waits for it (callback->wait()). 
+ template <typename T_Return, typename T_IDevice, typename T_Callback = std::nullptr_t> + Return<T_Return> recoverable(const char* context, + const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn, + const T_Callback& callback = nullptr) const EXCLUDES(mMutex); + + // The name of the service that implements the driver. + const std::string mServiceName; + + // Guards access to mCore. + mutable std::shared_mutex mMutex; + + // Data that can be rewritten during driver recovery. Guarded against + // synchronous access by a mutex: Any number of concurrent read accesses is + // permitted, but a write access excludes all other accesses. + mutable Core mCore GUARDED_BY(mMutex); }; /** This class wraps an IPreparedModel object of any version. */ |