diff options
author | Michael Butler <butlermichael@google.com> | 2019-02-25 18:05:44 -0800 |
---|---|---|
committer | Michael Butler <butlermichael@google.com> | 2019-03-07 15:09:05 -0800 |
commit | ae060300a5dec3f1b900d7de20b56f6e43b4e6a5 (patch) | |
tree | e755ed5b5a680c75a34df6538205a46b870a6668 /nn/runtime/Manager.cpp | |
parent | 7d91badb48517c86b995df64a909e94dfd8b5e5b (diff) | |
download | ml-ae060300a5dec3f1b900d7de20b56f6e43b4e6a5.tar.gz |
Prevent asynchronous calls from hanging when service crashes
This CL introduces hidl death recipients to prevent asynchronous calls
from hanging when a service dies. The general process is as follows:
(1) Register HIDL death recipient ("DeathHandler") when an interface
object is constructed.
(2) Register I*Callback object with DeathHandler immediately before an
asynchronous execution.
(3) Immediately wait for the callback object to be populated with either
(a) the results of a valid asynchronous execution
(b) the error code provided by the DeathHandler if the service has
died
(4) Unregister I*Callback object with DeathHandler after the callback
object contains a result.
(5) Repeat 2-4 as many times as necessary.
(6) Unregister HIDL death recipient ("DeathHandler") when the HIDL
interface object is being destroyed.
As an extra convenience and protection, #4 (callback unregistration)
happens automatically when the current scope is exiting via an
RAII-manager "ScopeGuard" object.
Fixes: 118624080
Test: ran "NeuralNetworksTest_static --gtest_filter=GeneratedTests.add",
killed the sample-all driver, and confirmed (1) that the runtime was not
blocked and (2) that the appropriate log message was recorded. NOTE: this was
facilitated by adding a 10 second sleep in the sample driver for the
asynchronous preparation and asynchronous execution, enabling the service
to be manually killed via "adb shell kill -9 <pid>".
Change-Id: Ie90de41eadfe0d8a1e71f5079962f27ae97852cb
Diffstat (limited to 'nn/runtime/Manager.cpp')
-rw-r--r-- | nn/runtime/Manager.cpp | 39 |
1 files changed, 22 insertions, 17 deletions
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp index 6c15daf08..15c928fdd 100644 --- a/nn/runtime/Manager.cpp +++ b/nn/runtime/Manager.cpp @@ -83,9 +83,9 @@ class DriverDevice : public Device { const char* getName() const override { return mName.c_str(); } const char* getVersionString() const override { return mVersionString.c_str(); } - VersionedIDevice* getInterface() override { return &mInterface; } - int64_t getFeatureLevel() override { return mInterface.getFeatureLevel(); } - int32_t getType() const override { return mInterface.getType(); } + VersionedIDevice* getInterface() override { return mInterface.get(); } + int64_t getFeatureLevel() override { return mInterface->getFeatureLevel(); } + int32_t getType() const override { return mInterface->getType(); } hidl_vec<Extension> getSupportedExtensions() const override; void getSupportedOperations(const Model& hidlModel, hidl_vec<bool>* supportedOperations) override; @@ -104,14 +104,14 @@ class DriverDevice : public Device { private: int prepareModelHelper( - const std::function<Return<ErrorStatus>(const sp<IPreparedModelCallback>& callback)>& + const std::function<Return<ErrorStatus>(const sp<PreparedModelCallback>& callback)>& prepare, const std::string& prepareName, std::shared_ptr<VersionedIPreparedModel>* preparedModel); std::string mName; std::string mVersionString; - VersionedIDevice mInterface; + const std::shared_ptr<VersionedIDevice> mInterface; PerformanceInfo mFloat32Performance; PerformanceInfo mQuantized8Performance; PerformanceInfo mRelaxedFloat32toFloat16Performance; @@ -127,7 +127,7 @@ class DriverDevice : public Device { }; DriverDevice::DriverDevice(std::string name, const sp<V1_0::IDevice>& device) - : mName(std::move(name)), mInterface(device) {} + : mName(std::move(name)), mInterface(VersionedIDevice::create(device)) {} // TODO: handle errors from initialize correctly bool DriverDevice::initialize() { @@ -141,8 +141,13 @@ bool DriverDevice::initialize() { ErrorStatus status = ErrorStatus::GENERAL_FAILURE; + if (mInterface == nullptr) { + LOG(ERROR) << "DriverDevice contains invalid interface object."; + return false; + } + Capabilities capabilities; - std::tie(status, capabilities) = mInterface.getCapabilities(); + std::tie(status, capabilities) = mInterface->getCapabilities(); if (status != ErrorStatus::NONE) { LOG(ERROR) << "IDevice::getCapabilities returned the error " << toString(status); return false; @@ -154,20 +159,20 @@ bool DriverDevice::initialize() { mQuantized8Performance = capabilities.quantized8Performance; mRelaxedFloat32toFloat16Performance = capabilities.relaxedFloat32toFloat16Performance; - std::tie(status, mVersionString) = mInterface.getVersionString(); + std::tie(status, mVersionString) = mInterface->getVersionString(); // TODO(miaowang): add a validation test case for in case of error. if (status != ErrorStatus::NONE) { LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(status); return false; } - std::tie(status, mSupportedExtensions) = mInterface.getSupportedExtensions(); + std::tie(status, mSupportedExtensions) = mInterface->getSupportedExtensions(); if (status != ErrorStatus::NONE) { LOG(ERROR) << "IDevice::getSupportedExtensions returned the error " << toString(status); return false; } - std::tie(status, mIsCachingSupported) = mInterface.isCachingSupported(); + std::tie(status, mIsCachingSupported) = mInterface->isCachingSupported(); if (status != ErrorStatus::NONE) { LOG(WARNING) << "IDevice::isCachingSupported returned the error " << toString(status); mIsCachingSupported = false; @@ -184,7 +189,7 @@ void DriverDevice::getSupportedOperations(const Model& hidlModel, // Query the driver for what it can do. ErrorStatus status = ErrorStatus::GENERAL_FAILURE; hidl_vec<bool> supportedOperations; - std::tie(status, supportedOperations) = mInterface.getSupportedOperations(hidlModel); + std::tie(status, supportedOperations) = mInterface->getSupportedOperations(hidlModel); if (status != ErrorStatus::NONE) { LOG(ERROR) << "IDevice::getSupportedOperations returned the error " << toString(status); @@ -245,7 +250,7 @@ void DriverDevice::getSupportedOperations(const Model& hidlModel, // Compilation logic copied from StepExecutor::startComputeOnDevice(). int DriverDevice::prepareModelHelper( - const std::function<Return<ErrorStatus>(const sp<IPreparedModelCallback>& callback)>& + const std::function<Return<ErrorStatus>(const sp<PreparedModelCallback>& callback)>& prepare, const std::string& prepareName, std::shared_ptr<VersionedIPreparedModel>* preparedModel) { *preparedModel = nullptr; @@ -266,7 +271,7 @@ int DriverDevice::prepareModelHelper( preparedModelCallback->wait(); ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus(); if (auto returnedPreparedModel = preparedModelCallback->getPreparedModel()) { - *preparedModel = std::make_shared<VersionedIPreparedModel>(returnedPreparedModel); + *preparedModel = VersionedIPreparedModel::create(returnedPreparedModel); } if (prepareReturnStatus != ErrorStatus::NONE || *preparedModel == nullptr) { LOG(ERROR) << prepareName << " on " << getName() << " failed:" @@ -282,8 +287,8 @@ int DriverDevice::prepareModel(const Model& hidlModel, ExecutionPreference execu // Note that some work within VersionedIDevice will be subtracted from the IPC layer NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "prepareModel"); return prepareModelHelper( - [this, &hidlModel, &executionPreference](const sp<IPreparedModelCallback>& callback) { - return mInterface.prepareModel(hidlModel, executionPreference, callback); + [this, &hidlModel, &executionPreference](const sp<PreparedModelCallback>& callback) { + return mInterface->prepareModel(hidlModel, executionPreference, callback); }, "prepareModel", preparedModel); } @@ -294,8 +299,8 @@ int DriverDevice::prepareModelFromCache(const hidl_handle& cache1, const hidl_ha // Note that some work within VersionedIDevice will be subtracted from the IPC layer NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "prepareModelFromCache"); return prepareModelHelper( - [this, &cache1, &cache2, &token](const sp<IPreparedModelCallback>& callback) { - return mInterface.prepareModelFromCache(cache1, cache2, token, callback); + [this, &cache1, &cache2, &token](const sp<PreparedModelCallback>& callback) { + return mInterface->prepareModelFromCache(cache1, cache2, token, callback); }, "prepareModelFromCache", preparedModel); } |