diff options
author | Michael Butler <butlermichael@google.com> | 2019-05-22 11:36:58 -0700 |
---|---|---|
committer | Michael Butler <butlermichael@google.com> | 2019-05-24 00:18:59 +0000 |
commit | 7d1ea8b70c42e0a4182a78630f032af0cee0bfb0 (patch) | |
tree | aec60c55c0935c519e0f0d1651e33eb2d1a03b0d | |
parent | c8747bb09bd63bf7d4e01bd4625de05cd83bb6f8 (diff) | |
download | ml-7d1ea8b70c42e0a4182a78630f032af0cee0bfb0.tar.gz |
Prevent hang for asynchronous calls when transport failure occurs
Prior to this CL, asynchronous calls were protected with the following
usage pattern:
(1) the callback object is registered for protection
(2) the asynchronous call that uses the callback object is invoked
(3) callback->wait() is called to wait for the asynchronous results
(4) the error status of launching the call is checked
(5) the callback object is unregistered for protection when leaving
scope
However, if a transport error occured when launching the asynchronous
execution (e.g., the data being sent across HIDL exceeds a preset
limit, resulting in a transport error), the code will continue waiting
at (3) before it can check the launch error in (4).
This CL fixes this by checking the launch status before waiting for the
results. Additionally, because VersionedIPreparedModel::execute takes in
a callback as an argument from the caller, extra protections are put in
place to notify the callback in the event that the asynchronous call
could not be launched because of unexpected behavior from the driver or
internal problems in the runtime.
Bug: 133325508
Bug: 118624080
Test: mma
Test: NeuralNetworksTest_static
Test: CtsNNAPITestCases
Test: ran "NeuralNetworksTest_static --gtest_filter=GeneratedTests.add",
killed the sample-minimal driver, and confirmed (1) that the runtime
was not blocked and (2) that the appropriate log message was recorded.
NOTE: this was facilitated by adding a 10 second sleep in the sample
driver for the asynchronous preparation and asynchronous execution,
enabling the service to be manually killed via
"adb shell kill -9 <pid>".
Test: ran "NeuralNetworksTest_static --gtest_filter=GeneratedTests.add",
with local modifications to the sample driver to have it return an
error message for launching an asynchronous call, but not make the
corresponding call to callback->notify. Ensured the runtime still
progressed and the appropriate messages were logged. Confirmed that
without the changes in this CL, a hang occurs.
Test: ran "NeuralNetworksTest_static --gtest_filter=GeneratedTests.add",
with local modifications to the runtime to sleep for 10 seconds before
calling an asynchronous call. In this window, the sample-minimal
driver was manually killed (via "adb shell kill -9 <pid>"), prompting
a transport failure. Ensured the runtime still progressed and the
appropriate messages were logged.
Change-Id: Ic4e8cc8399b1e30fadfaf01842ce62550ad2223f
-rw-r--r-- | nn/runtime/VersionedInterfaces.cpp | 29 |
1 files changed, 23 insertions, 6 deletions
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index c72c7d822..d0439b213 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -44,6 +44,10 @@ void sendFailureMessage(const sp<IExecutionCallback>& cb) { cb->notify(ErrorStatus::GENERAL_FAILURE); } +void sendFailureMessage(const sp<ExecutionCallback>& cb) { + sendFailureMessage(static_cast<sp<IExecutionCallback>>(cb)); +} + // This class is thread safe template <typename ICallback> class DeathHandler : public hardware::hidl_death_recipient { @@ -133,21 +137,34 @@ ErrorStatus VersionedIPreparedModel::execute(const Request& request, MeasureTimi if (mPreparedModelV1_2 != nullptr) { Return<ErrorStatus> ret = mPreparedModelV1_2->execute_1_2(request, measure, callback); - callback->wait(); if (!ret.isOk()) { + sendFailureMessage(callback); LOG(ERROR) << "execute_1_2 failure: " << ret.description(); return ErrorStatus::GENERAL_FAILURE; } + if (ret != ErrorStatus::NONE) { + sendFailureMessage(callback); + LOG(ERROR) << "execute_1_2 returned " << toString(static_cast<ErrorStatus>(ret)); + return static_cast<ErrorStatus>(ret); + } + callback->wait(); return static_cast<ErrorStatus>(ret); } else if (mPreparedModelV1_0 != nullptr) { Return<ErrorStatus> ret = mPreparedModelV1_0->execute(request, callback); - callback->wait(); if (!ret.isOk()) { + sendFailureMessage(callback); LOG(ERROR) << "execute failure: " << ret.description(); return ErrorStatus::GENERAL_FAILURE; } + if (ret != ErrorStatus::NONE) { + sendFailureMessage(callback); + LOG(ERROR) << "execute returned " << toString(static_cast<ErrorStatus>(ret)); + return static_cast<ErrorStatus>(ret); + } + callback->wait(); return static_cast<ErrorStatus>(ret); } else { + sendFailureMessage(callback); LOG(ERROR) << "execute called with no preparedModel"; return ErrorStatus::GENERAL_FAILURE; } @@ -432,7 +449,6 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic if (mDeviceV1_2 != nullptr) { const Return<ErrorStatus> ret = mDeviceV1_2->prepareModel_1_2(model, preference, modelCache, dataCache, token, callback); - callback->wait(); if (!ret.isOk()) { LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description(); return kFailure; @@ -441,6 +457,7 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic LOG(ERROR) << "prepareModel_1_2 returned " << toString(static_cast<ErrorStatus>(ret)); return kFailure; } + callback->wait(); return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())}; } @@ -462,7 +479,6 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic if (compliant) { const Return<ErrorStatus> ret = mDeviceV1_1->prepareModel_1_1(model11, preference, callback); - callback->wait(); if (!ret.isOk()) { LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description(); return kFailure; @@ -472,6 +488,7 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic << toString(static_cast<ErrorStatus>(ret)); return kFailure; } + callback->wait(); return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())}; } @@ -499,7 +516,6 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic } if (compliant) { const Return<ErrorStatus> ret = mDeviceV1_0->prepareModel(model10, callback); - callback->wait(); if (!ret.isOk()) { LOG(ERROR) << "prepareModel failure: " << ret.description(); return kFailure; @@ -508,6 +524,7 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic LOG(ERROR) << "prepareModel returned " << toString(static_cast<ErrorStatus>(ret)); return kFailure; } + callback->wait(); return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())}; } @@ -541,7 +558,6 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache, if (mDeviceV1_2 != nullptr) { const Return<ErrorStatus> ret = mDeviceV1_2->prepareModelFromCache(modelCache, dataCache, token, callback); - callback->wait(); if (!ret.isOk()) { LOG(ERROR) << "prepareModelFromCache failure: " << ret.description(); return kFailure; @@ -551,6 +567,7 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache, << toString(static_cast<ErrorStatus>(ret)); return kFailure; } + callback->wait(); return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())}; } |