summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Gross <dgross@google.com>2019-05-31 02:08:32 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2019-05-31 02:08:32 +0000
commit6109869f3050bcb5d03f62102de8b02f961238d2 (patch)
tree3de1b25f8b6b1357f792e0c11aa7cd4c71d3c2a1
parent2658cca4a66ade24f9d3efdab368266f784dd06d (diff)
parent00fa8d35cbd38d5ea7cab48dc96dcb9851578aab (diff)
downloadml-6109869f3050bcb5d03f62102de8b02f961238d2.tar.gz
Merge "Partially recover from a driver crash" into qt-dev
-rw-r--r--nn/runtime/Manager.cpp2
-rw-r--r--nn/runtime/VersionedInterfaces.cpp363
-rw-r--r--nn/runtime/VersionedInterfaces.h233
3 files changed, 459 insertions, 139 deletions
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index c5d7b1fad..b479d3964 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -96,7 +96,7 @@ class DriverDevice : public Device {
};
DriverDevice::DriverDevice(std::string name, const sp<V1_0::IDevice>& device)
- : mName(std::move(name)), mInterface(VersionedIDevice::create(device)) {}
+ : mName(std::move(name)), mInterface(VersionedIDevice::create(mName, device)) {}
// TODO: handle errors from initialize correctly
bool DriverDevice::initialize() {
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index d0439b213..0a359582d 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -25,6 +25,7 @@
#include <android-base/scopeguard.h>
#include <android-base/thread_annotations.h>
#include <functional>
+#include <type_traits>
namespace android {
namespace nn {
@@ -40,6 +41,10 @@ void sendFailureMessage(const sp<IPreparedModelCallback>& cb) {
cb->notify(ErrorStatus::GENERAL_FAILURE, nullptr);
}
+void sendFailureMessage(const sp<PreparedModelCallback>& cb) {
+ sendFailureMessage(static_cast<sp<IPreparedModelCallback>>(cb));
+}
+
void sendFailureMessage(const sp<IExecutionCallback>& cb) {
cb->notify(ErrorStatus::GENERAL_FAILURE);
}
@@ -219,18 +224,33 @@ bool VersionedIPreparedModel::operator!=(nullptr_t) const {
return mPreparedModelV1_0 != nullptr;
}
-std::shared_ptr<VersionedIDevice> VersionedIDevice::create(sp<V1_0::IDevice> device) {
+std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
+ sp<V1_0::IDevice> device) {
+ auto core = Core::create(std::move(device));
+ if (!core.has_value()) {
+ LOG(ERROR) << "VersionedIDevice::create -- Failed to create Core.";
+ return nullptr;
+ }
+
+ // return a valid VersionedIDevice object
+ return std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value()));
+}
+
+VersionedIDevice::VersionedIDevice(std::string serviceName, Core core)
+ : mServiceName(std::move(serviceName)), mCore(std::move(core)) {}
+
+std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
// verify input
if (!device) {
- LOG(ERROR) << "VersionedIDevice::create -- passed invalid device object.";
- return nullptr;
+ LOG(ERROR) << "VersionedIDevice::Core::create -- passed invalid device object.";
+ return {};
}
// create death handler object
sp<IDeviceDeathHandler> deathHandler = new (std::nothrow) IDeviceDeathHandler();
if (!deathHandler) {
- LOG(ERROR) << "VersionedIDevice::create -- Failed to create IDeviceDeathHandler.";
- return nullptr;
+ LOG(ERROR) << "VersionedIDevice::Core::create -- Failed to create IDeviceDeathHandler.";
+ return {};
}
// linkToDeath registers a callback that will be invoked on service death to
@@ -239,60 +259,192 @@ std::shared_ptr<VersionedIDevice> VersionedIDevice::create(sp<V1_0::IDevice> dev
// providing the response.
const Return<bool> ret = device->linkToDeath(deathHandler, 0);
if (!ret.isOk() || ret != true) {
- LOG(ERROR) << "VersionedIDevice::create -- Failed to register a death recipient for the "
- "IDevice object.";
- return nullptr;
+ LOG(ERROR)
+ << "VersionedIDevice::Core::create -- Failed to register a death recipient for the "
+ "IDevice object.";
+ return {};
}
- // return a valid VersionedIDevice object
- return std::make_shared<VersionedIDevice>(std::move(device), std::move(deathHandler));
+ // return a valid Core object
+ return Core(std::move(device), std::move(deathHandler));
}
// HIDL guarantees all V1_1 interfaces inherit from their corresponding V1_0 interfaces.
-VersionedIDevice::VersionedIDevice(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler)
+VersionedIDevice::Core::Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler)
: mDeviceV1_0(std::move(device)),
mDeviceV1_1(V1_1::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
mDeviceV1_2(V1_2::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
mDeathHandler(std::move(deathHandler)) {}
-VersionedIDevice::~VersionedIDevice() {
- // It is safe to ignore any errors resulting from this unlinkToDeath call
- // because the VersionedIDevice object is already being destroyed and its
- // underlying IDevice object is no longer being used by the NN runtime.
- mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk();
+VersionedIDevice::Core::~Core() {
+ if (mDeathHandler != nullptr) {
+ CHECK(mDeviceV1_0 != nullptr);
+ // It is safe to ignore any errors resulting from this unlinkToDeath call
+ // because the VersionedIDevice::Core object is already being destroyed and
+ // its underlying IDevice object is no longer being used by the NN runtime.
+ mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk();
+ }
+}
+
+VersionedIDevice::Core::Core(Core&& other) noexcept
+ : mDeviceV1_0(std::move(other.mDeviceV1_0)),
+ mDeviceV1_1(std::move(other.mDeviceV1_1)),
+ mDeviceV1_2(std::move(other.mDeviceV1_2)),
+ mDeathHandler(std::move(other.mDeathHandler)) {
+ other.mDeathHandler = nullptr;
+}
+
+VersionedIDevice::Core& VersionedIDevice::Core::operator=(Core&& other) noexcept {
+ if (this != &other) {
+ mDeviceV1_0 = std::move(other.mDeviceV1_0);
+ mDeviceV1_1 = std::move(other.mDeviceV1_1);
+ mDeviceV1_2 = std::move(other.mDeviceV1_2);
+ mDeathHandler = std::move(other.mDeathHandler);
+ other.mDeathHandler = nullptr;
+ }
+ return *this;
+}
+
+template <typename T_IDevice>
+std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> VersionedIDevice::Core::getDeviceAndDeathHandler()
+ const {
+ return {getDevice<T_IDevice>(), mDeathHandler};
+}
+
+template <typename T_IDevice, typename T_Callback>
+Return<ErrorStatus> callProtected(
+ const char* context, const std::function<Return<ErrorStatus>(const sp<T_IDevice>&)>& fn,
+ const sp<T_IDevice>& device, const sp<T_Callback>& callback,
+ const sp<IDeviceDeathHandler>& deathHandler) {
+ const auto scoped = deathHandler->protectCallback(callback);
+ Return<ErrorStatus> ret = fn(device);
+ // Suppose there was a transport error. We have the following cases:
+ // 1. Either not due to a dead device, or due to a device that was
+ // already dead at the time of the call to protectCallback(). In
+ // this case, the callback was never signalled.
+ // 2. Due to a device that died after the call to protectCallback() but
+ // before fn() completed. In this case, the callback was (or will
+ // be) signalled by the deathHandler.
+ // Furthermore, what if there was no transport error, but the ErrorStatus is
+ // other than NONE? We'll conservatively signal the callback anyway, just in
+ // case the driver was sloppy and failed to do so.
+ if (!ret.isOk() || ret != ErrorStatus::NONE) {
+ // What if the deathHandler has signalled or will signal the callback?
+ // This is fine -- we're permitted to signal multiple times; and we're
+ // sending the same signal that the deathHandler does.
+ //
+ // What if the driver signalled the callback? Then this signal is
+ // ignored.
+
+ if (ret.isOk()) {
+ LOG(ERROR) << context << " returned " << toString(static_cast<ErrorStatus>(ret));
+ } else {
+ LOG(ERROR) << context << " failure: " << ret.description();
+ }
+ sendFailureMessage(callback);
+ }
+ callback->wait();
+ return ret;
+}
+template <typename T_Return, typename T_IDevice>
+Return<T_Return> callProtected(const char*,
+ const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
+ const sp<T_IDevice>& device, const std::nullptr_t&,
+ const sp<IDeviceDeathHandler>&) {
+ return fn(device);
+}
+
+template <typename T_Return, typename T_IDevice, typename T_Callback>
+Return<T_Return> VersionedIDevice::recoverable(
+ const char* context, const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
+ const T_Callback& callback) const EXCLUDES(mMutex) {
+ CHECK_EQ(callback == nullptr, (std::is_same_v<T_Callback, std::nullptr_t>));
+
+ sp<T_IDevice> device;
+ sp<IDeviceDeathHandler> deathHandler;
+ std::tie(device, deathHandler) = getDeviceAndDeathHandler<T_IDevice>();
+
+ Return<T_Return> ret = callProtected(context, fn, device, callback, deathHandler);
+
+ if (ret.isDeadObject()) {
+ {
+ std::unique_lock lock(mMutex);
+ // It's possible that another device has already done the recovery.
+ // It's harmless but wasteful for us to do so in this case.
+ auto pingReturn = mCore.getDevice<T_IDevice>()->ping();
+ if (pingReturn.isDeadObject()) {
+ VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering "
+ << mServiceName;
+ sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(mServiceName);
+ if (recoveredDevice == nullptr) {
+ VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for "
+ << mServiceName;
+ return ret;
+ }
+
+ auto core = Core::create(std::move(recoveredDevice));
+ if (!core.has_value()) {
+ LOG(ERROR) << "VersionedIDevice::recoverable -- Failed to create Core.";
+ return ret;
+ }
+
+ mCore = std::move(core.value());
+ } else {
+ VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context
+ << ") -- Someone else recovered " << mServiceName;
+ // Might still have a transport error, which we need to check
+ // before pingReturn goes out of scope.
+ (void)pingReturn.isOk();
+ }
+ std::tie(device, deathHandler) = mCore.getDeviceAndDeathHandler<T_IDevice>();
+ }
+ ret = callProtected(context, fn, device, callback, deathHandler);
+ // It's possible that the device died again, but we're only going to
+ // attempt recovery once per call to recoverable().
+ }
+ return ret;
}
std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() {
const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
std::pair<ErrorStatus, Capabilities> result;
- if (mDeviceV1_2 != nullptr) {
+ if (getDevice<V1_2::IDevice>() != nullptr) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_2");
- Return<void> ret = mDeviceV1_2->getCapabilities_1_2(
- [&result](ErrorStatus error, const Capabilities& capabilities) {
- result = std::make_pair(error, capabilities);
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getCapabilities_1_2(
+ [&result](ErrorStatus error, const Capabilities& capabilities) {
+ result = std::make_pair(error, capabilities);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getCapabilities_1_2 failure: " << ret.description();
return {ErrorStatus::GENERAL_FAILURE, {}};
}
- } else if (mDeviceV1_1 != nullptr) {
+ } else if (getDevice<V1_1::IDevice>() != nullptr) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_1");
- Return<void> ret = mDeviceV1_1->getCapabilities_1_1(
- [&result](ErrorStatus error, const V1_1::Capabilities& capabilities) {
- // Time taken to convert capabilities is trivial
- result = std::make_pair(error, convertToV1_2(capabilities));
+ Return<void> ret = recoverable<void, V1_1::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_1::IDevice>& device) {
+ return device->getCapabilities_1_1(
+ [&result](ErrorStatus error, const V1_1::Capabilities& capabilities) {
+ // Time taken to convert capabilities is trivial
+ result = std::make_pair(error, convertToV1_2(capabilities));
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getCapabilities_1_1 failure: " << ret.description();
return kFailure;
}
- } else if (mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_0::IDevice>() != nullptr) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities");
- Return<void> ret = mDeviceV1_0->getCapabilities(
- [&result](ErrorStatus error, const V1_0::Capabilities& capabilities) {
- // Time taken to convert capabilities is trivial
- result = std::make_pair(error, convertToV1_2(capabilities));
+ Return<void> ret = recoverable<void, V1_0::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_0::IDevice>& device) {
+ return device->getCapabilities(
+ [&result](ErrorStatus error, const V1_0::Capabilities& capabilities) {
+ // Time taken to convert capabilities is trivial
+ result = std::make_pair(error, convertToV1_2(capabilities));
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getCapabilities failure: " << ret.description();
@@ -309,18 +461,21 @@ std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() {
std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() {
const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedExtensions");
- if (mDeviceV1_2 != nullptr) {
+ if (getDevice<V1_2::IDevice>() != nullptr) {
std::pair<ErrorStatus, hidl_vec<Extension>> result;
- Return<void> ret = mDeviceV1_2->getSupportedExtensions(
- [&result](ErrorStatus error, const hidl_vec<Extension>& extensions) {
- result = std::make_pair(error, extensions);
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getSupportedExtensions(
+ [&result](ErrorStatus error, const hidl_vec<Extension>& extensions) {
+ result = std::make_pair(error, extensions);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getSupportedExtensions failure: " << ret.description();
return kFailure;
}
return result;
- } else if (mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_0::IDevice>() != nullptr) {
return {ErrorStatus::NONE, {/* No extensions. */}};
} else {
LOG(ERROR) << "Device not available!";
@@ -354,11 +509,14 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
return std::make_pair(status, std::move(remappedSupported));
};
- if (mDeviceV1_2 != nullptr) {
+ if (getDevice<V1_2::IDevice>() != nullptr) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_2");
- Return<void> ret = mDeviceV1_2->getSupportedOperations_1_2(
- model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
- result = std::make_pair(error, supported);
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&model, &result](const sp<V1_2::IDevice>& device) {
+ return device->getSupportedOperations_1_2(
+ model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
+ result = std::make_pair(error, supported);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getSupportedOperations_1_2 failure: " << ret.description();
@@ -367,7 +525,7 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
return result;
}
- if (mDeviceV1_1 != nullptr) {
+ if (getDevice<V1_1::IDevice>() != nullptr) {
const bool compliant = compliantWithV1_1(model);
if (compliant || slicer) {
V1_1::Model model11;
@@ -383,9 +541,13 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
}
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION,
"getSupportedOperations_1_1");
- Return<void> ret = mDeviceV1_1->getSupportedOperations_1_1(
- model11, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
- result = std::make_pair(error, supported);
+ Return<void> ret = recoverable<void, V1_1::IDevice>(
+ __FUNCTION__, [&model11, &result](const sp<V1_1::IDevice>& device) {
+ return device->getSupportedOperations_1_1(
+ model11,
+ [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
+ result = std::make_pair(error, supported);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getSupportedOperations_1_1 failure: " << ret.description();
@@ -398,7 +560,7 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
return result;
}
- if (mDeviceV1_0 != nullptr) {
+ if (getDevice<V1_0::IDevice>() != nullptr) {
const bool compliant = compliantWithV1_0(model);
if (compliant || slicer) {
V1_0::Model model10;
@@ -413,9 +575,13 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
std::tie(model10, submodelOperationIndexToModelOperationIndex) = *slice10;
}
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations");
- Return<void> ret = mDeviceV1_0->getSupportedOperations(
- model10, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
- result = std::make_pair(error, supported);
+ Return<void> ret = recoverable<void, V1_0::IDevice>(
+ __FUNCTION__, [&model10, &result](const sp<V1_0::IDevice>& device) {
+ return device->getSupportedOperations(
+ model10,
+ [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
+ result = std::make_pair(error, supported);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getSupportedOperations failure: " << ret.description();
@@ -443,12 +609,16 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
return kFailure;
}
- const auto scoped = mDeathHandler->protectCallback(callback);
-
// If 1.2 device, try preparing model
- if (mDeviceV1_2 != nullptr) {
- const Return<ErrorStatus> ret = mDeviceV1_2->prepareModel_1_2(model, preference, modelCache,
- dataCache, token, callback);
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>(
+ __FUNCTION__,
+ [&model, &preference, &modelCache, &dataCache, &token,
+ &callback](const sp<V1_2::IDevice>& device) {
+ return device->prepareModel_1_2(model, preference, modelCache, dataCache, token,
+ callback);
+ },
+ callback);
if (!ret.isOk()) {
LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
return kFailure;
@@ -462,7 +632,7 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
}
// If 1.1 device, try preparing model (requires conversion)
- if (mDeviceV1_1 != nullptr) {
+ if (getDevice<V1_1::IDevice>() != nullptr) {
bool compliant = false;
V1_1::Model model11;
{
@@ -477,8 +647,12 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
}
}
if (compliant) {
- const Return<ErrorStatus> ret =
- mDeviceV1_1->prepareModel_1_1(model11, preference, callback);
+ const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_1::IDevice>(
+ __FUNCTION__,
+ [&model11, &preference, &callback](const sp<V1_1::IDevice>& device) {
+ return device->prepareModel_1_1(model11, preference, callback);
+ },
+ callback);
if (!ret.isOk()) {
LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
return kFailure;
@@ -493,14 +667,12 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
makeVersionedIPreparedModel(callback->getPreparedModel())};
}
- // TODO: partition the model such that v1.2 ops are not passed to v1.1
- // device
LOG(ERROR) << "Could not handle prepareModel_1_1!";
return kFailure;
}
// If 1.0 device, try preparing model (requires conversion)
- if (mDeviceV1_0 != nullptr) {
+ if (getDevice<V1_0::IDevice>() != nullptr) {
bool compliant = false;
V1_0::Model model10;
{
@@ -515,7 +687,12 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
}
}
if (compliant) {
- const Return<ErrorStatus> ret = mDeviceV1_0->prepareModel(model10, callback);
+ const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_0::IDevice>(
+ __FUNCTION__,
+ [&model10, &callback](const sp<V1_0::IDevice>& device) {
+ return device->prepareModel(model10, callback);
+ },
+ callback);
if (!ret.isOk()) {
LOG(ERROR) << "prepareModel failure: " << ret.description();
return kFailure;
@@ -529,8 +706,6 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
makeVersionedIPreparedModel(callback->getPreparedModel())};
}
- // TODO: partition the model such that v1.1 ops are not passed to v1.0
- // device
LOG(ERROR) << "Could not handle prepareModel!";
return kFailure;
}
@@ -553,11 +728,13 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
return kFailure;
}
- const auto scoped = mDeathHandler->protectCallback(callback);
-
- if (mDeviceV1_2 != nullptr) {
- const Return<ErrorStatus> ret =
- mDeviceV1_2->prepareModelFromCache(modelCache, dataCache, token, callback);
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>(
+ __FUNCTION__,
+ [&modelCache, &dataCache, &token, &callback](const sp<V1_2::IDevice>& device) {
+ return device->prepareModelFromCache(modelCache, dataCache, token, callback);
+ },
+ callback);
if (!ret.isOk()) {
LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
return kFailure;
@@ -571,7 +748,7 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())};
}
- if (mDeviceV1_1 != nullptr || mDeviceV1_0 != nullptr) {
+ if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
LOG(ERROR) << "prepareModelFromCache called on V1_1 or V1_0 device";
return kFailure;
}
@@ -581,12 +758,13 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
}
DeviceStatus VersionedIDevice::getStatus() {
- if (mDeviceV1_0 == nullptr) {
+ if (getDevice<V1_0::IDevice>() == nullptr) {
LOG(ERROR) << "Device not available!";
return DeviceStatus::UNKNOWN;
}
- Return<DeviceStatus> ret = mDeviceV1_0->getStatus();
+ Return<DeviceStatus> ret = recoverable<DeviceStatus, V1_0::IDevice>(
+ __FUNCTION__, [](const sp<V1_0::IDevice>& device) { return device->getStatus(); });
if (!ret.isOk()) {
LOG(ERROR) << "getStatus failure: " << ret.description();
@@ -598,11 +776,11 @@ DeviceStatus VersionedIDevice::getStatus() {
int64_t VersionedIDevice::getFeatureLevel() {
constexpr int64_t kFailure = -1;
- if (mDeviceV1_2 != nullptr) {
+ if (getDevice<V1_2::IDevice>() != nullptr) {
return __ANDROID_API_Q__;
- } else if (mDeviceV1_1 != nullptr) {
+ } else if (getDevice<V1_1::IDevice>() != nullptr) {
return __ANDROID_API_P__;
- } else if (mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_0::IDevice>() != nullptr) {
return __ANDROID_API_O_MR1__;
} else {
LOG(ERROR) << "Device not available!";
@@ -614,10 +792,12 @@ int32_t VersionedIDevice::getType() const {
constexpr int32_t kFailure = -1;
std::pair<ErrorStatus, DeviceType> result;
- if (mDeviceV1_2 != nullptr) {
- Return<void> ret =
- mDeviceV1_2->getType([&result](ErrorStatus error, DeviceType deviceType) {
- result = std::make_pair(error, deviceType);
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getType([&result](ErrorStatus error, DeviceType deviceType) {
+ result = std::make_pair(error, deviceType);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getType failure: " << ret.description();
@@ -634,17 +814,20 @@ std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() {
const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""};
std::pair<ErrorStatus, hidl_string> result;
- if (mDeviceV1_2 != nullptr) {
- Return<void> ret = mDeviceV1_2->getVersionString(
- [&result](ErrorStatus error, const hidl_string& version) {
- result = std::make_pair(error, version);
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getVersionString(
+ [&result](ErrorStatus error, const hidl_string& version) {
+ result = std::make_pair(error, version);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getVersion failure: " << ret.description();
return kFailure;
}
return result;
- } else if (mDeviceV1_1 != nullptr || mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
return {ErrorStatus::NONE, "UNKNOWN"};
} else {
LOG(ERROR) << "Could not handle getVersionString";
@@ -657,17 +840,21 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi
0, 0};
std::tuple<ErrorStatus, uint32_t, uint32_t> result;
- if (mDeviceV1_2 != nullptr) {
- Return<void> ret = mDeviceV1_2->getNumberOfCacheFilesNeeded(
- [&result](ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
- result = {error, numModelCache, numDataCache};
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getNumberOfCacheFilesNeeded([&result](ErrorStatus error,
+ uint32_t numModelCache,
+ uint32_t numDataCache) {
+ result = {error, numModelCache, numDataCache};
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getNumberOfCacheFilesNeeded failure: " << ret.description();
return kFailure;
}
return result;
- } else if (mDeviceV1_1 != nullptr || mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
return {ErrorStatus::NONE, 0, 0};
} else {
LOG(ERROR) << "Could not handle getNumberOfCacheFilesNeeded";
@@ -676,11 +863,11 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi
}
bool VersionedIDevice::operator==(nullptr_t) const {
- return mDeviceV1_0 == nullptr;
+ return getDevice<V1_0::IDevice>() == nullptr;
}
bool VersionedIDevice::operator!=(nullptr_t) const {
- return mDeviceV1_0 != nullptr;
+ return getDevice<V1_0::IDevice>() != nullptr;
}
} // namespace nn
diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h
index 81a66013f..f0aaf6d6e 100644
--- a/nn/runtime/VersionedInterfaces.h
+++ b/nn/runtime/VersionedInterfaces.h
@@ -20,9 +20,14 @@
#include "HalInterfaces.h"
#include <android-base/macros.h>
+#include <cstddef>
+#include <functional>
#include <memory>
+#include <optional>
+#include <shared_mutex>
#include <string>
#include <tuple>
+#include <utility>
#include "Callbacks.h"
namespace android {
@@ -53,6 +58,9 @@ class VersionedIPreparedModel;
class VersionedIDevice {
DISALLOW_IMPLICIT_CONSTRUCTORS(VersionedIDevice);
+ // forward declaration of nested class
+ class Core;
+
public:
/**
* Create a VersionedIDevice object.
@@ -60,40 +68,26 @@ class VersionedIDevice {
* Prefer using this function over the constructor, as it adds more
* protections.
*
- * This call linksToDeath a hidl_death_recipient that can
- * proactively handle the case when the service containing the IDevice
- * object crashes.
- *
+ * @param serviceName The name of the service that provides "device".
* @param device A device object that is at least version 1.0 of the IDevice
* interface.
* @return A valid VersionedIDevice object, otherwise nullptr.
*/
- static std::shared_ptr<VersionedIDevice> create(sp<V1_0::IDevice> device);
+ static std::shared_ptr<VersionedIDevice> create(std::string serviceName,
+ sp<V1_0::IDevice> device);
/**
* Constructor for the VersionedIDevice object.
*
- * VersionedIDevice is constructed with the V1_0::IDevice object, which
- * represents a device that is at least v1.0 of the interface. The
- * constructor downcasts to the latest version of the IDevice interface, and
- * will default to using the latest version of all IDevice interface
- * methods automatically.
- *
- * @param device A device object that is at least version 1.0 of the IDevice
- * interface.
- * @param deathHandler A hidl_death_recipient that will proactively handle
- * the case when the service containing the IDevice
- * object crashes.
- */
- VersionedIDevice(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler);
-
- /**
- * Destructor for the VersionedIDevice object.
+ * VersionedIDevice will default to using the latest version of all IDevice
+ * interface methods automatically.
*
- * This destructor unlinksToDeath this object's hidl_death_recipient as it
- * no longer needs to handle the case where the IDevice's service crashes.
+ * @param serviceName The name of the service that provides core.getDevice<V1_0::IDevice>().
+ * @param core An object that encapsulates a V1_0::IDevice, any appropriate downcasts to
+ * newer interfaces, and a hidl_death_recipient that will proactively handle
+ * the case when the service containing the IDevice object crashes.
*/
- ~VersionedIDevice();
+ VersionedIDevice(std::string serviceName, Core core);
/**
* Gets the capabilities of a driver.
@@ -426,6 +420,13 @@ class VersionedIDevice {
std::tuple<ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeeded();
/**
+ * Returns the name of the service that implements the driver
+ *
+ * @return serviceName The name of the service.
+ */
+ std::string getServiceName() const { return mServiceName; }
+
+ /**
* Returns whether this handle to an IDevice object is valid or not.
*
* @return bool true if V1_0::IDevice (which could be V1_1::IDevice) is
@@ -443,33 +444,165 @@ class VersionedIDevice {
private:
/**
- * All versions of IDevice are necessary because the driver could be v1.0,
- * v1.1, or a later version. All these pointers logically represent the same
- * object.
- *
- * The general strategy is: HIDL returns a V1_0 device object, which
- * (if not nullptr) could be v1.0, v1.1, or a greater version. The V1_0
- * object is then "dynamically cast" to a V1_1 object. If successful,
- * mDeviceV1_1 will point to the same object as mDeviceV1_0; otherwise,
- * mDeviceV1_1 will be nullptr.
- *
- * In general:
- * * If the device is truly v1.0, mDeviceV1_0 will point to a valid object
- * and mDeviceV1_1 will be nullptr.
- * * If the device is truly v1.1 or later, both mDeviceV1_0 and mDeviceV1_1
- * will point to the same valid object.
- *
- * Idiomatic usage: if mDeviceV1_1 is non-null, do V1_1 dispatch; otherwise,
- * do V1_0 dispatch.
- */
- sp<V1_0::IDevice> mDeviceV1_0;
- sp<V1_1::IDevice> mDeviceV1_1;
- sp<V1_2::IDevice> mDeviceV1_2;
-
- /**
- * HIDL callback to be invoked if the service for mDeviceV1_0 crashes.
+ * This is a utility class for VersionedIDevice that encapsulates a
+ * V1_0::IDevice, any appropriate downcasts to newer interfaces, and a
+ * hidl_death_recipient that will proactively handle the case when the
+ * service containing the IDevice object crashes.
+ *
+ * This is a convenience class to help VersionedIDevice recover from an
+ * IDevice object crash: It bundles together all the data that needs to
+ * change when recovering from a crash, and simplifies the process of
+ * instantiating that data (at VersionedIDevice creation time) and
+ * re-instantiating that data (at crash recovery time).
*/
- const sp<IDeviceDeathHandler> mDeathHandler;
+ class Core {
+ public:
+ /**
+ * Constructor for the Core object.
+ *
+ * Core is constructed with a V1_0::IDevice object, which represents a
+ * device that is at least v1.0 of the interface. The constructor
+ * downcasts to the latest version of the IDevice interface, allowing
+ * VersionedIDevice to default to using the latest version of all
+ * IDevice interface methods automatically.
+ *
+ * @param device A device object that is at least version 1.0 of the IDevice
+ * interface.
+ * @param deathHandler A hidl_death_recipient that will proactively handle
+ * the case when the service containing the IDevice
+ * object crashes.
+ */
+ Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler);
+
+ /**
+ * Destructor for the Core object.
+ *
+ * This destructor unlinksToDeath this object's hidl_death_recipient as it
+ * no longer needs to handle the case where the IDevice's service crashes.
+ */
+ ~Core();
+
+ // Support move but not copy
+ Core(Core&&) noexcept;
+ Core& operator=(Core&&) noexcept;
+ Core(const Core&) = delete;
+ Core& operator=(const Core&) = delete;
+
+ /**
+ * Create a Core object.
+ *
+ * Prefer using this function over the constructor, as it adds more
+ * protections.
+ *
+ * This call linksToDeath a hidl_death_recipient that can
+ * proactively handle the case when the service containing the IDevice
+ * object crashes.
+ *
+ * @param device A device object that is at least version 1.0 of the IDevice
+ * interface.
+ * @return A valid Core object, otherwise nullopt.
+ */
+ static std::optional<Core> create(sp<V1_0::IDevice> device);
+
+ /**
+ * Returns sp<*::IDevice> that is a downcast of the sp<V1_0::IDevice>
+ * passed to the constructor. This will be nullptr if that IDevice is
+ * not actually of the specified downcast type.
+ */
+ template <typename T_IDevice>
+ sp<T_IDevice> getDevice() const;
+ template <>
+ sp<V1_0::IDevice> getDevice() const {
+ return mDeviceV1_0;
+ }
+ template <>
+ sp<V1_1::IDevice> getDevice() const {
+ return mDeviceV1_1;
+ }
+ template <>
+ sp<V1_2::IDevice> getDevice() const {
+ return mDeviceV1_2;
+ }
+
+ /**
+ * Returns sp<*::IDevice> (as per getDevice()) and the
+ * hidl_death_recipient that will proactively handle the case when the
+ * service containing the IDevice object crashes.
+ */
+ template <typename T_IDevice>
+ std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> getDeviceAndDeathHandler() const;
+
+ private:
+ /**
+ * All versions of IDevice are necessary because the driver could be v1.0,
+ * v1.1, or a later version. All these pointers logically represent the same
+ * object.
+ *
+ * The general strategy is: HIDL returns a V1_0 device object, which
+ * (if not nullptr) could be v1.0, v1.1, or a greater version. The V1_0
+ * object is then "dynamically cast" to a V1_1 object. If successful,
+ * mDeviceV1_1 will point to the same object as mDeviceV1_0; otherwise,
+ * mDeviceV1_1 will be nullptr.
+ *
+ * In general:
+ * * If the device is truly v1.0, mDeviceV1_0 will point to a valid object
+ * and mDeviceV1_1 will be nullptr.
+ * * If the device is truly v1.1 or later, both mDeviceV1_0 and mDeviceV1_1
+ * will point to the same valid object.
+ *
+ * Idiomatic usage: if mDeviceV1_1 is non-null, do V1_1 dispatch; otherwise,
+ * do V1_0 dispatch.
+ */
+ sp<V1_0::IDevice> mDeviceV1_0;
+ sp<V1_1::IDevice> mDeviceV1_1;
+ sp<V1_2::IDevice> mDeviceV1_2;
+
+ /**
+ * HIDL callback to be invoked if the service for mDeviceV1_0 crashes.
+ *
+ * nullptr if this Core instance is a move victim and hence has no
+ * callback to be unlinked.
+ */
+ sp<IDeviceDeathHandler> mDeathHandler;
+ };
+
+ // This method retrieves the appropriate mCore.mDevice* field, under a read lock.
+ template <typename T_IDevice>
+ sp<T_IDevice> getDevice() const EXCLUDES(mMutex) {
+ std::shared_lock lock(mMutex);
+ return mCore.getDevice<T_IDevice>();
+ }
+
+ // This method retrieves the appropriate mCore.mDevice* fields, under a read lock.
+ template <typename T_IDevice>
+ auto getDeviceAndDeathHandler() const EXCLUDES(mMutex) {
+ std::shared_lock lock(mMutex);
+ return mCore.getDeviceAndDeathHandler<T_IDevice>();
+ }
+
+ // This method calls the function fn in a manner that supports recovering
+ // from a driver crash: If the driver implementation is dead because the
+ // driver crashed either before the call to fn or during the call to fn, we
+ // will attempt to obtain a new instance of the same driver and call fn
+ // again.
+ //
+ // If a callback is provided, this method protects it against driver death
+ // and waits for it (callback->wait()).
+ template <typename T_Return, typename T_IDevice, typename T_Callback = std::nullptr_t>
+ Return<T_Return> recoverable(const char* context,
+ const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
+ const T_Callback& callback = nullptr) const EXCLUDES(mMutex);
+
+ // The name of the service that implements the driver.
+ const std::string mServiceName;
+
+ // Guards access to mCore.
+ mutable std::shared_mutex mMutex;
+
+ // Data that can be rewritten during driver recovery. Guarded againt
+ // synchronous access by a mutex: Any number of concurrent read accesses is
+ // permitted, but a write access excludes all other accesses.
+ mutable Core mCore GUARDED_BY(mMutex);
};
/** This class wraps an IPreparedModel object of any version. */