summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Butler <butlermichael@google.com>2019-05-10 18:50:02 -0700
committerMichael Butler <butlermichael@google.com>2019-05-14 16:01:34 -0700
commit75b49d08a640c10f7c70a391ea49d8c808485807 (patch)
tree676f876fee4e3bfe5431334106567a1ba1974e49
parent9ec2e391f4fe045140b80608f79e1108e37d82f7 (diff)
downloadml-75b49d08a640c10f7c70a391ea49d8c808485807.tar.gz
Protect asynchronous burst calls from hanging
Burst execution is asynchronous at the HAL layer: the runtime sends a request packet across one FMQ, then waits for the response to be received on the result FMQ. However, if the driver crashes after the request request has been made but before the result has been received, the runtime will hang. Specifically, the call to ResultChannelReceiver::getPacketBlocking will never return. This CL adds a death recipient to detect when the driver has crashed and to unblock the runtime by returning a failure. The death recipient additionally marks the sender and receiver objects as invalid, causing any subsequent calls to send or receive a packet to immediately return a failure in order to avoid future hangs. This CL additionally returns a value to the runtime to indicate whether the burst execution should be re-run using another execution path, such as IPreparedModel::executeSynchronously* or IPreparedModel::execute. ExecutionBurstController will request a re-run either when (1) the request packet failed to send across the FMQ (e.g., when the number of elements in the packet exceeded the size of the FMQ) or (2) when the burst object has been marked as invalid. Test: mma Test: ran NeuralNetworksTest_static, made the sample driver's burst execution artificially long, killed the sample driver, and confirmed (1) the runtime recovered and (2) the appropriate log messages appeared in logcat Bug: 129157135 Bug: 131086786 Change-Id: I04fcb6247dc78ea057c7596682159af1f9025235
-rw-r--r--nn/common/ExecutionBurstController.cpp126
-rw-r--r--nn/common/include/ExecutionBurstController.h54
-rw-r--r--nn/runtime/ExecutionBuilder.cpp71
3 files changed, 187 insertions, 64 deletions
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp
index be8be7105..55ef9ad29 100644
--- a/nn/common/ExecutionBurstController.cpp
+++ b/nn/common/ExecutionBurstController.cpp
@@ -34,6 +34,23 @@ using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
std::numeric_limits<uint64_t>::max()};
+class BurstContextDeathHandler : public hardware::hidl_death_recipient {
+ public:
+ using Callback = std::function<void()>;
+
+ BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
+ CHECK(onDeathCallback != nullptr);
+ }
+
+ void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
+ LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
+ mOnDeathCallback();
+ }
+
+ private:
+ const Callback mOnDeathCallback;
+};
+
} // anonymous namespace
// serialize a request into a packet
@@ -229,13 +246,13 @@ ResultChannelReceiver::getBlocking() {
}
void ResultChannelReceiver::invalidate() {
- mTeardown = true;
+ mValid = false;
// force unblock
- // ExecutionBurstServer is by default waiting on a request packet. If the
- // client process destroys its burst object, the server will still be
- // waiting on the futex (assuming mBlocking is true). This force unblock
- // wakes up any thread waiting on the futex.
+ // ExecutionBurstController waits on a result packet after sending a
+ // request. If the driver containing ExecutionBurstServer crashes, the
+ // controller will still be waiting on the futex (assuming mBlocking is
+ // true). This force unblock wakes up any thread waiting on the futex.
if (mBlocking) {
// TODO: look for a different/better way to signal/notify the futex to
// wake up any thread waiting on it
@@ -249,7 +266,7 @@ void ResultChannelReceiver::invalidate() {
std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
using discriminator = FmqResultDatum::hidl_discriminator;
- if (mTeardown) {
+ if (!mValid) {
return std::nullopt;
}
@@ -259,7 +276,7 @@ std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlock
if (mBlocking) {
success = mFmqResultChannel->readBlocking(&datum, 1);
} else {
- while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
+ while ((success = mValid.load(std::memory_order_relaxed)) &&
!mFmqResultChannel->read(&datum, 1)) {
}
}
@@ -276,8 +293,7 @@ std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlock
std::memcpy(&packet.front(), &datum, sizeof(datum));
success &= mFmqResultChannel->read(packet.data() + 1, count);
- // terminate loop
- if (mTeardown) {
+ if (!mValid) {
return std::nullopt;
}
@@ -315,6 +331,10 @@ bool RequestChannelSender::send(const Request& request, MeasureTiming measure,
}
bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
+ if (!mValid) {
+ return false;
+ }
+
if (packet.size() > mFmqRequestChannel->availableToWrite()) {
LOG(ERROR)
<< "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
@@ -328,6 +348,10 @@ bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet
}
}
+void RequestChannelSender::invalidate() {
+ mValid = false;
+}
+
Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
const hidl_vec<int32_t>& slots, getMemories_cb cb) {
std::lock_guard<std::mutex> guard(mMutex);
@@ -420,19 +444,20 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
// create callback object
sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
- if (callback == nullptr) {
- LOG(ERROR) << "ExecutionBurstController::create failed to create callback";
- return nullptr;
- }
// create FMQ objects
- auto [fmqRequestChannel, fmqRequestDescriptor] =
+ auto [requestChannelSenderTemp, requestChannelDescriptor] =
RequestChannelSender::create(kExecutionBurstChannelLength, blocking);
- auto [fmqResultChannel, fmqResultDescriptor] =
+ auto [resultChannelReceiverTemp, resultChannelDescriptor] =
ResultChannelReceiver::create(kExecutionBurstChannelLength, blocking);
+ std::shared_ptr<RequestChannelSender> requestChannelSender =
+ std::move(requestChannelSenderTemp);
+ std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
+ std::move(resultChannelReceiverTemp);
// check FMQ objects
- if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestDescriptor || !fmqResultDescriptor) {
+ if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
+ !resultChannelDescriptor) {
LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
return nullptr;
}
@@ -441,7 +466,7 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
ErrorStatus errorStatus;
sp<IBurstContext> burstContext;
const Return<void> ret = preparedModel->configureExecutionBurst(
- callback, *fmqRequestDescriptor, *fmqResultDescriptor,
+ callback, *requestChannelDescriptor, *resultChannelDescriptor,
[&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
errorStatus = status;
burstContext = context;
@@ -463,22 +488,61 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
return nullptr;
}
+ // create death handler object
+ BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
+ resultChannelReceiver] {
+ requestChannelSender->invalidate();
+ resultChannelReceiver->invalidate();
+ };
+ const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
+
+ // linkToDeath registers a callback that will be invoked on service death to
+ // proactively handle service crashes. If the linkToDeath call fails,
+ // asynchronous calls are susceptible to hangs if the service crashes before
+ // providing the response.
+ const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
+ if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
+ LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
+ "for the IBurstContext object.";
+ return nullptr;
+ }
+
// make and return controller
- return std::make_unique<ExecutionBurstController>(
- std::move(fmqRequestChannel), std::move(fmqResultChannel), burstContext, callback);
+ return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
+ burstContext, callback, deathHandler);
}
ExecutionBurstController::ExecutionBurstController(
- std::unique_ptr<RequestChannelSender> requestChannelSender,
- std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
- const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback)
- : mRequestChannelSender(std::move(requestChannelSender)),
- mResultChannelReceiver(std::move(resultChannelReceiver)),
+ const std::shared_ptr<RequestChannelSender>& requestChannelSender,
+ const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
+ const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
+ const sp<hardware::hidl_death_recipient>& deathHandler)
+ : mRequestChannelSender(requestChannelSender),
+ mResultChannelReceiver(resultChannelReceiver),
mBurstContext(burstContext),
- mMemoryCache(callback) {}
+ mMemoryCache(callback),
+ mDeathHandler(deathHandler) {}
+
+ExecutionBurstController::~ExecutionBurstController() {
+ // It is safe to ignore any errors resulting from this unlinkToDeath call
+ // because the ExecutionBurstController object is already being destroyed
+ // and its underlying IBurstContext object is no longer being used by the NN
+ // runtime.
+ if (mDeathHandler) {
+ mBurstContext->unlinkToDeath(mDeathHandler).isOk();
+ }
+}
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute(
const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
+ auto [status, outputShapes, timing, fallback] = tryCompute(request, measure, memoryIds);
+ (void)fallback; // ignore fallback field
+ return {status, std::move(outputShapes), timing};
+}
+
+std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool>
+ExecutionBurstController::tryCompute(const Request& request, MeasureTiming measure,
+ const std::vector<intptr_t>& memoryIds) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
std::lock_guard<std::mutex> guard(mMutex);
@@ -488,16 +552,22 @@ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstControll
const bool success = mRequestChannelSender->send(request, measure, slots);
if (!success) {
LOG(ERROR) << "Error sending FMQ packet";
- return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
+ // only use fallback execution path if the packet could not be sent
+ return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/true};
}
// get result packet
const auto result = mResultChannelReceiver->getBlocking();
if (!result) {
LOG(ERROR) << "Error retrieving FMQ packet";
- return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
+ // only use fallback execution path if the packet could not be sent
+ return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/false};
}
- return *result;
+
+ // unpack results and return (only use fallback execution path if the
+ // packet could not be sent)
+ auto [status, outputShapes, timing] = std::move(*result);
+ return {status, std::move(outputShapes), timing, /*fallback=*/false};
}
void ExecutionBurstController::freeMemory(intptr_t key) {
diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h
index 3397d9673..030ea4d2d 100644
--- a/nn/common/include/ExecutionBurstController.h
+++ b/nn/common/include/ExecutionBurstController.h
@@ -117,7 +117,7 @@ class ResultChannelReceiver {
private:
const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
- std::atomic<bool> mTeardown{false};
+ std::atomic<bool> mValid{true};
const bool mBlocking;
};
@@ -157,6 +157,13 @@ class RequestChannelSender {
*/
bool send(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots);
+ /**
+ * Method to mark the channel as invalid, causing all future calls to
+ * RequestChannelSender::send to immediately return false without attempting
+ * to send a message across the FMQ.
+ */
+ void invalidate();
+
// prefer calling RequestChannelSender::send
bool sendPacket(const std::vector<FmqRequestDatum>& packet);
@@ -164,6 +171,7 @@ class RequestChannelSender {
private:
const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
+ std::atomic<bool> mValid{true};
const bool mBlocking;
};
@@ -259,10 +267,15 @@ class ExecutionBurstController {
static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel,
bool blocking);
- ExecutionBurstController(std::unique_ptr<RequestChannelSender> requestChannelSender,
- std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
+ // prefer calling ExecutionBurstController::create
+ ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender,
+ const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
const sp<IBurstContext>& burstContext,
- const sp<ExecutionBurstCallback>& callback);
+ const sp<ExecutionBurstCallback>& callback,
+ const sp<hardware::hidl_death_recipient>& deathHandler = nullptr);
+
+ // explicit destructor to unregister the death recipient
+ ~ExecutionBurstController();
/**
* Execute a request on a model.
@@ -271,12 +284,36 @@ class ExecutionBurstController {
* @param measure Whether to collect timing measurements, either YES or NO
* @param memoryIds Identifiers corresponding to each memory object in the
* request's pools.
- * @return status and output shape of the execution and any execution time
- * measurements.
+ * @return A tuple of:
+ * - status of the execution
+ * - dynamic output shapes from the execution
+ * - any execution time measurements of the execution
*/
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute(
const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);
+ // TODO: combine "compute" and "tryCompute" back into a single function.
+ // "tryCompute" was created later to return the "fallback" boolean. This
+ // could not be done directly in "compute" because the VTS test cases (which
+ // test burst using "compute") had already been locked down and could not be
+ // changed.
+ /**
+ * Execute a request on a model.
+ *
+ * @param request Arguments to be executed on a model.
+ * @param measure Whether to collect timing measurements, either YES or NO
+ * @param memoryIds Identifiers corresponding to each memory object in the
+ * request's pools.
+ * @return A tuple of:
+ * - status of the execution
+ * - dynamic output shapes from the execution
+ * - any execution time measurements of the execution
+ * - whether or not a failed burst execution should be re-run using a
+ * different path (e.g., IPreparedModel::executeSynchronously)
+ */
+ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> tryCompute(
+ const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);
+
/**
* Propagate a user's freeing of memory to the service.
*
@@ -286,10 +323,11 @@ class ExecutionBurstController {
private:
std::mutex mMutex;
- const std::unique_ptr<RequestChannelSender> mRequestChannelSender;
- const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver;
+ const std::shared_ptr<RequestChannelSender> mRequestChannelSender;
+ const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver;
const sp<IBurstContext> mBurstContext;
const sp<ExecutionBurstCallback> mMemoryCache;
+ const sp<hardware::hidl_death_recipient> mDeathHandler;
};
} // namespace android::nn
diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp
index 8866a5a05..35b6788b4 100644
--- a/nn/runtime/ExecutionBuilder.cpp
+++ b/nn/runtime/ExecutionBuilder.cpp
@@ -898,7 +898,10 @@ int StepExecutor::startComputeOnDevice(
// in the design document.
sp<ExecutionCallback> executionCallback = new ExecutionCallback();
- if (burstController != nullptr) {
+ // compute using burst if present
+ const bool burstCompute = (burstController != nullptr);
+ bool burstFallback = false;
+ if (burstCompute) {
std::vector<intptr_t> memoryIds;
memoryIds.reserve(mMemories.size());
for (const Memory* memory : mMemories) {
@@ -906,34 +909,46 @@ int StepExecutor::startComputeOnDevice(
memoryIds.push_back(memory->getKey());
}
- VLOG(EXECUTION) << "Before ExecutionBurstController->compute() "
+ VLOG(EXECUTION) << "Before ExecutionBurstController->tryCompute() "
<< SHOW_IF_DEBUG(toString(request));
- auto burstExecuteResult =
- burstController->compute(request, measureTiming(mExecutionBuilder), memoryIds);
- executionCallback->notify(std::get<0>(burstExecuteResult), std::get<1>(burstExecuteResult),
- std::get<2>(burstExecuteResult));
- } else if (DeviceManager::get()->syncExecHal()) {
- VLOG(EXECUTION) << "Before mPreparedModel->executeSynchronously() "
- << SHOW_IF_DEBUG(toString(request));
- auto syncExecuteResult =
- mPreparedModel->executeSynchronously(request, measureTiming(mExecutionBuilder));
- executionCallback->notify(std::get<0>(syncExecuteResult), std::get<1>(syncExecuteResult),
- std::get<2>(syncExecuteResult));
- } else {
- VLOG(EXECUTION) << "Before mPreparedModel->execute() " << SHOW_IF_DEBUG(toString(request));
- // Execute.
- // TODO: What happens to the Callback if the service dies abnormally
- // -- won't that keep the Callback live forever, because the service
- // never has the opportunity to bump the reference count down? Or
- // maybe the HIDL infrastructure handles this magically? At worst,
- // it seems like this is a small memory leak, if the Callback stays
- // alive forever.
- Return<ErrorStatus> executeStatus = mPreparedModel->execute(
- request, measureTiming(mExecutionBuilder), executionCallback);
- if (!executeStatus.isOk() || executeStatus != ErrorStatus::NONE) {
- VLOG(EXECUTION) << "**Execute launch failed**";
- return executeStatus.isOk() ? convertErrorStatusToResultCode(executeStatus)
- : ANEURALNETWORKS_OP_FAILED;
+ auto [status, outputShapes, timing, fallback] =
+ burstController->tryCompute(request, measureTiming(mExecutionBuilder), memoryIds);
+
+ burstFallback = fallback;
+ if (!fallback) {
+ executionCallback->notify(status, outputShapes, timing);
+ }
+ }
+
+ // compute from IPreparedModel if either:
+ // (1) burst was not supplied, or
+ // (2) the burst execution failed and requested a fallback execution
+ if (!burstCompute || burstFallback) {
+ if (DeviceManager::get()->syncExecHal()) {
+ VLOG(EXECUTION) << "Before mPreparedModel->executeSynchronously() "
+ << SHOW_IF_DEBUG(toString(request));
+ auto syncExecuteResult =
+ mPreparedModel->executeSynchronously(request, measureTiming(mExecutionBuilder));
+ executionCallback->notify(std::get<0>(syncExecuteResult),
+ std::get<1>(syncExecuteResult),
+ std::get<2>(syncExecuteResult));
+ } else {
+ VLOG(EXECUTION) << "Before mPreparedModel->execute() "
+ << SHOW_IF_DEBUG(toString(request));
+ // Execute.
+ // TODO: What happens to the Callback if the service dies abnormally
+ // -- won't that keep the Callback live forever, because the service
+ // never has the opportunity to bump the reference count down? Or
+ // maybe the HIDL infrastructure handles this magically? At worst,
+ // it seems like this is a small memory leak, if the Callback stays
+ // alive forever.
+ Return<ErrorStatus> executeStatus = mPreparedModel->execute(
+ request, measureTiming(mExecutionBuilder), executionCallback);
+ if (!executeStatus.isOk() || executeStatus != ErrorStatus::NONE) {
+ VLOG(EXECUTION) << "**Execute launch failed**";
+ return executeStatus.isOk() ? convertErrorStatusToResultCode(executeStatus)
+ : ANEURALNETWORKS_OP_FAILED;
+ }
}
}