summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Butler <butlermichael@google.com>2019-05-22 11:36:58 -0700
committerMichael Butler <butlermichael@google.com>2019-05-24 00:18:59 +0000
commit7d1ea8b70c42e0a4182a78630f032af0cee0bfb0 (patch)
treeaec60c55c0935c519e0f0d1651e33eb2d1a03b0d
parentc8747bb09bd63bf7d4e01bd4625de05cd83bb6f8 (diff)
downloadml-7d1ea8b70c42e0a4182a78630f032af0cee0bfb0.tar.gz
Prevent hang for asynchronous calls when transport failure occurs
Prior to this CL, asynchronous calls were protected with the following usage pattern: (1) the callback object is registered for protection (2) the asynchronous call that uses the callback object is invoked (3) callback->wait() is called to wait for the asynchronous results (4) the error status of launching the call is checked (5) the callback object is unregistered for protection when leaving scope However, if a transport error occured when launching the asynchronous execution (e.g., the data being sent across HIDL exceeds a preset limit, resulting in a transport error), the code will continue waiting at (3) before it can check the launch error in (4). This CL fixes this by checking the launch status before waiting for the results. Additionally, because VersionedIPreparedModel::execute takes in a callback as an argument from the caller, extra protections are put in place to notify the callback in the event that the asynchronous call could not be launched because of unexpected behavior from the driver or internal problems in the runtime. Bug: 133325508 Bug: 118624080 Test: mma Test: NeuralNetworksTest_static Test: CtsNNAPITestCases Test: ran "NeuralNetworksTest_static --gtest_filter=GeneratedTests.add", killed the sample-minimal driver, and confirmed (1) that the runtime was not blocked and (2) that the appropriate log message was recorded. NOTE: this was facilitated by adding a 10 second sleep in the sample driver for the asynchronous preparation and asynchronous execution, enabling the service to be manually killed via "adb shell kill -9 <pid>". Test: ran "NeuralNetworksTest_static --gtest_filter=GeneratedTests.add", with local modifications to the sample driver to have it return an error message for launching an asynchronous call, but not make the corresponding call to callback->notify. Ensured the runtime still progressed and the appropriate messages were logged. Confirmed that without the changes in this CL, a hang occurs. Test: ran "NeuralNetworksTest_static --gtest_filter=GeneratedTests.add", with local modifications to the runtime to sleep for 10 seconds before calling an asynchronous call. In this window, the sample-minimal driver was manually killed (via "adb shell kill -9 <pid>"), prompting a transport failure. Ensured the runtime still progressed and the appropriate messages were logged. Change-Id: Ic4e8cc8399b1e30fadfaf01842ce62550ad2223f
-rw-r--r--nn/runtime/VersionedInterfaces.cpp29
1 files changed, 23 insertions, 6 deletions
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index c72c7d822..d0439b213 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -44,6 +44,10 @@ void sendFailureMessage(const sp<IExecutionCallback>& cb) {
cb->notify(ErrorStatus::GENERAL_FAILURE);
}
+void sendFailureMessage(const sp<ExecutionCallback>& cb) {
+ sendFailureMessage(static_cast<sp<IExecutionCallback>>(cb));
+}
+
// This class is thread safe
template <typename ICallback>
class DeathHandler : public hardware::hidl_death_recipient {
@@ -133,21 +137,34 @@ ErrorStatus VersionedIPreparedModel::execute(const Request& request, MeasureTimi
if (mPreparedModelV1_2 != nullptr) {
Return<ErrorStatus> ret = mPreparedModelV1_2->execute_1_2(request, measure, callback);
- callback->wait();
if (!ret.isOk()) {
+ sendFailureMessage(callback);
LOG(ERROR) << "execute_1_2 failure: " << ret.description();
return ErrorStatus::GENERAL_FAILURE;
}
+ if (ret != ErrorStatus::NONE) {
+ sendFailureMessage(callback);
+ LOG(ERROR) << "execute_1_2 returned " << toString(static_cast<ErrorStatus>(ret));
+ return static_cast<ErrorStatus>(ret);
+ }
+ callback->wait();
return static_cast<ErrorStatus>(ret);
} else if (mPreparedModelV1_0 != nullptr) {
Return<ErrorStatus> ret = mPreparedModelV1_0->execute(request, callback);
- callback->wait();
if (!ret.isOk()) {
+ sendFailureMessage(callback);
LOG(ERROR) << "execute failure: " << ret.description();
return ErrorStatus::GENERAL_FAILURE;
}
+ if (ret != ErrorStatus::NONE) {
+ sendFailureMessage(callback);
+ LOG(ERROR) << "execute returned " << toString(static_cast<ErrorStatus>(ret));
+ return static_cast<ErrorStatus>(ret);
+ }
+ callback->wait();
return static_cast<ErrorStatus>(ret);
} else {
+ sendFailureMessage(callback);
LOG(ERROR) << "execute called with no preparedModel";
return ErrorStatus::GENERAL_FAILURE;
}
@@ -432,7 +449,6 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
if (mDeviceV1_2 != nullptr) {
const Return<ErrorStatus> ret = mDeviceV1_2->prepareModel_1_2(model, preference, modelCache,
dataCache, token, callback);
- callback->wait();
if (!ret.isOk()) {
LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
return kFailure;
@@ -441,6 +457,7 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
LOG(ERROR) << "prepareModel_1_2 returned " << toString(static_cast<ErrorStatus>(ret));
return kFailure;
}
+ callback->wait();
return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())};
}
@@ -462,7 +479,6 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
if (compliant) {
const Return<ErrorStatus> ret =
mDeviceV1_1->prepareModel_1_1(model11, preference, callback);
- callback->wait();
if (!ret.isOk()) {
LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
return kFailure;
@@ -472,6 +488,7 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
<< toString(static_cast<ErrorStatus>(ret));
return kFailure;
}
+ callback->wait();
return {callback->getStatus(),
makeVersionedIPreparedModel(callback->getPreparedModel())};
}
@@ -499,7 +516,6 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
}
if (compliant) {
const Return<ErrorStatus> ret = mDeviceV1_0->prepareModel(model10, callback);
- callback->wait();
if (!ret.isOk()) {
LOG(ERROR) << "prepareModel failure: " << ret.description();
return kFailure;
@@ -508,6 +524,7 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
LOG(ERROR) << "prepareModel returned " << toString(static_cast<ErrorStatus>(ret));
return kFailure;
}
+ callback->wait();
return {callback->getStatus(),
makeVersionedIPreparedModel(callback->getPreparedModel())};
}
@@ -541,7 +558,6 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
if (mDeviceV1_2 != nullptr) {
const Return<ErrorStatus> ret =
mDeviceV1_2->prepareModelFromCache(modelCache, dataCache, token, callback);
- callback->wait();
if (!ret.isOk()) {
LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
return kFailure;
@@ -551,6 +567,7 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
<< toString(static_cast<ErrorStatus>(ret));
return kFailure;
}
+ callback->wait();
return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())};
}