diff options
author | Michael Butler <butlermichael@google.com> | 2019-05-10 18:50:02 -0700 |
---|---|---|
committer | Michael Butler <butlermichael@google.com> | 2019-05-14 16:01:34 -0700 |
commit | 75b49d08a640c10f7c70a391ea49d8c808485807 (patch) | |
tree | 676f876fee4e3bfe5431334106567a1ba1974e49 | |
parent | 9ec2e391f4fe045140b80608f79e1108e37d82f7 (diff) | |
download | ml-75b49d08a640c10f7c70a391ea49d8c808485807.tar.gz |
Protect asynchronous burst calls from hanging
Burst execution is asynchronous at the HAL layer: the runtime sends a
request packet across one FMQ, then waits for the response to be
received on the result FMQ. However, if the driver crashes after the
request has been made but before the result has been received,
the runtime will hang. Specifically, the call to
ResultChannelReceiver::getPacketBlocking will never return.
This CL adds a death recipient to detect when the driver has crashed and
to unblock the runtime by returning a failure. The death recipient
additionally marks the sender and receiver objects as invalid, causing
any subsequent calls to send or receive a packet to immediately return a
failure in order to avoid future hangs.
This CL additionally returns a value to the runtime to indicate whether
the burst execution should be re-run using another execution path, such
as IPreparedModel::executeSynchronously* or IPreparedModel::execute.
ExecutionBurstController will request a re-run when either (1) the
request packet failed to send across the FMQ (e.g., when the number of
elements in the packet exceeded the size of the FMQ) or (2) the burst
object has been marked as invalid.
Test: mma
Test: ran NeuralNetworksTest_static, made the sample driver's burst
execution artificially long, killed the sample driver, and
confirmed (1) the runtime recovered and (2) the appropriate log
messages appeared in logcat
Bug: 129157135
Bug: 131086786
Change-Id: I04fcb6247dc78ea057c7596682159af1f9025235
-rw-r--r-- | nn/common/ExecutionBurstController.cpp | 126 | ||||
-rw-r--r-- | nn/common/include/ExecutionBurstController.h | 54 | ||||
-rw-r--r-- | nn/runtime/ExecutionBuilder.cpp | 71 |
3 files changed, 187 insertions, 64 deletions
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp index be8be7105..55ef9ad29 100644 --- a/nn/common/ExecutionBurstController.cpp +++ b/nn/common/ExecutionBurstController.cpp @@ -34,6 +34,23 @@ using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>; constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(), std::numeric_limits<uint64_t>::max()}; +class BurstContextDeathHandler : public hardware::hidl_death_recipient { + public: + using Callback = std::function<void()>; + + BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) { + CHECK(onDeathCallback != nullptr); + } + + void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override { + LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!"; + mOnDeathCallback(); + } + + private: + const Callback mOnDeathCallback; +}; + } // anonymous namespace // serialize a request into a packet @@ -229,13 +246,13 @@ ResultChannelReceiver::getBlocking() { } void ResultChannelReceiver::invalidate() { - mTeardown = true; + mValid = false; // force unblock - // ExecutionBurstServer is by default waiting on a request packet. If the - // client process destroys its burst object, the server will still be - // waiting on the futex (assuming mBlocking is true). This force unblock - // wakes up any thread waiting on the futex. + // ExecutionBurstController waits on a result packet after sending a + // request. If the driver containing ExecutionBurstServer crashes, the + // controller will still be waiting on the futex (assuming mBlocking is + // true). This force unblock wakes up any thread waiting on the futex. 
if (mBlocking) { // TODO: look for a different/better way to signal/notify the futex to // wake up any thread waiting on it @@ -249,7 +266,7 @@ void ResultChannelReceiver::invalidate() { std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() { using discriminator = FmqResultDatum::hidl_discriminator; - if (mTeardown) { + if (!mValid) { return std::nullopt; } @@ -259,7 +276,7 @@ std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlock if (mBlocking) { success = mFmqResultChannel->readBlocking(&datum, 1); } else { - while ((success = !mTeardown.load(std::memory_order_relaxed)) && + while ((success = mValid.load(std::memory_order_relaxed)) && !mFmqResultChannel->read(&datum, 1)) { } } @@ -276,8 +293,7 @@ std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlock std::memcpy(&packet.front(), &datum, sizeof(datum)); success &= mFmqResultChannel->read(packet.data() + 1, count); - // terminate loop - if (mTeardown) { + if (!mValid) { return std::nullopt; } @@ -315,6 +331,10 @@ bool RequestChannelSender::send(const Request& request, MeasureTiming measure, } bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) { + if (!mValid) { + return false; + } + if (packet.size() > mFmqRequestChannel->availableToWrite()) { LOG(ERROR) << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ"; @@ -328,6 +348,10 @@ bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet } } +void RequestChannelSender::invalidate() { + mValid = false; +} + Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories( const hidl_vec<int32_t>& slots, getMemories_cb cb) { std::lock_guard<std::mutex> guard(mMutex); @@ -420,19 +444,20 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create( // create callback object sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback(); - if (callback == nullptr) { - 
LOG(ERROR) << "ExecutionBurstController::create failed to create callback"; - return nullptr; - } // create FMQ objects - auto [fmqRequestChannel, fmqRequestDescriptor] = + auto [requestChannelSenderTemp, requestChannelDescriptor] = RequestChannelSender::create(kExecutionBurstChannelLength, blocking); - auto [fmqResultChannel, fmqResultDescriptor] = + auto [resultChannelReceiverTemp, resultChannelDescriptor] = ResultChannelReceiver::create(kExecutionBurstChannelLength, blocking); + std::shared_ptr<RequestChannelSender> requestChannelSender = + std::move(requestChannelSenderTemp); + std::shared_ptr<ResultChannelReceiver> resultChannelReceiver = + std::move(resultChannelReceiverTemp); // check FMQ objects - if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestDescriptor || !fmqResultDescriptor) { + if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor || + !resultChannelDescriptor) { LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue"; return nullptr; } @@ -441,7 +466,7 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create( ErrorStatus errorStatus; sp<IBurstContext> burstContext; const Return<void> ret = preparedModel->configureExecutionBurst( - callback, *fmqRequestDescriptor, *fmqResultDescriptor, + callback, *requestChannelDescriptor, *resultChannelDescriptor, [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) { errorStatus = status; burstContext = context; @@ -463,22 +488,61 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create( return nullptr; } + // create death handler object + BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender, + resultChannelReceiver] { + requestChannelSender->invalidate(); + resultChannelReceiver->invalidate(); + }; + const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback); + + // linkToDeath registers a callback that will be invoked on 
service death to + // proactively handle service crashes. If the linkToDeath call fails, + // asynchronous calls are susceptible to hangs if the service crashes before + // providing the response. + const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0); + if (!deathHandlerRet.isOk() || deathHandlerRet != true) { + LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient " + "for the IBurstContext object."; + return nullptr; + } + // make and return controller - return std::make_unique<ExecutionBurstController>( - std::move(fmqRequestChannel), std::move(fmqResultChannel), burstContext, callback); + return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver, + burstContext, callback, deathHandler); } ExecutionBurstController::ExecutionBurstController( - std::unique_ptr<RequestChannelSender> requestChannelSender, - std::unique_ptr<ResultChannelReceiver> resultChannelReceiver, - const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback) - : mRequestChannelSender(std::move(requestChannelSender)), - mResultChannelReceiver(std::move(resultChannelReceiver)), + const std::shared_ptr<RequestChannelSender>& requestChannelSender, + const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver, + const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback, + const sp<hardware::hidl_death_recipient>& deathHandler) + : mRequestChannelSender(requestChannelSender), + mResultChannelReceiver(resultChannelReceiver), mBurstContext(burstContext), - mMemoryCache(callback) {} + mMemoryCache(callback), + mDeathHandler(deathHandler) {} + +ExecutionBurstController::~ExecutionBurstController() { + // It is safe to ignore any errors resulting from this unlinkToDeath call + // because the ExecutionBurstController object is already being destroyed + // and its underlying IBurstContext object is no longer being used by the NN + // runtime. 
+ if (mDeathHandler) { + mBurstContext->unlinkToDeath(mDeathHandler).isOk(); + } +} std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute( const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) { + auto [status, outputShapes, timing, fallback] = tryCompute(request, measure, memoryIds); + (void)fallback; // ignore fallback field + return {status, std::move(outputShapes), timing}; +} + +std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> +ExecutionBurstController::tryCompute(const Request& request, MeasureTiming measure, + const std::vector<intptr_t>& memoryIds) { NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute"); std::lock_guard<std::mutex> guard(mMutex); @@ -488,16 +552,22 @@ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstControll const bool success = mRequestChannelSender->send(request, measure, slots); if (!success) { LOG(ERROR) << "Error sending FMQ packet"; - return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming}; + // only use fallback execution path if the packet could not be sent + return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/true}; } // get result packet const auto result = mResultChannelReceiver->getBlocking(); if (!result) { LOG(ERROR) << "Error retrieving FMQ packet"; - return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming}; + // only use fallback execution path if the packet could not be sent + return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/false}; } - return *result; + + // unpack results and return (only use fallback execution path if the + // packet could not be sent) + auto [status, outputShapes, timing] = std::move(*result); + return {status, std::move(outputShapes), timing, /*fallback=*/false}; } void ExecutionBurstController::freeMemory(intptr_t key) { diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h index 
3397d9673..030ea4d2d 100644 --- a/nn/common/include/ExecutionBurstController.h +++ b/nn/common/include/ExecutionBurstController.h @@ -117,7 +117,7 @@ class ResultChannelReceiver { private: const std::unique_ptr<FmqResultChannel> mFmqResultChannel; - std::atomic<bool> mTeardown{false}; + std::atomic<bool> mValid{true}; const bool mBlocking; }; @@ -157,6 +157,13 @@ class RequestChannelSender { */ bool send(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots); + /** + * Method to mark the channel as invalid, causing all future calls to + * RequestChannelSender::send to immediately return false without attempting + * to send a message across the FMQ. + */ + void invalidate(); + // prefer calling RequestChannelSender::send bool sendPacket(const std::vector<FmqRequestDatum>& packet); @@ -164,6 +171,7 @@ class RequestChannelSender { private: const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel; + std::atomic<bool> mValid{true}; const bool mBlocking; }; @@ -259,10 +267,15 @@ class ExecutionBurstController { static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel, bool blocking); - ExecutionBurstController(std::unique_ptr<RequestChannelSender> requestChannelSender, - std::unique_ptr<ResultChannelReceiver> resultChannelReceiver, + // prefer calling ExecutionBurstController::create + ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender, + const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver, const sp<IBurstContext>& burstContext, - const sp<ExecutionBurstCallback>& callback); + const sp<ExecutionBurstCallback>& callback, + const sp<hardware::hidl_death_recipient>& deathHandler = nullptr); + + // explicit destructor to unregister the death recipient + ~ExecutionBurstController(); /** * Execute a request on a model. 
@@ -271,12 +284,36 @@ class ExecutionBurstController { * @param measure Whether to collect timing measurements, either YES or NO * @param memoryIds Identifiers corresponding to each memory object in the * request's pools. - * @return status and output shape of the execution and any execution time - * measurements. + * @return A tuple of: + * - status of the execution + * - dynamic output shapes from the execution + * - any execution time measurements of the execution */ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute( const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds); + // TODO: combine "compute" and "tryCompute" back into a single function. + // "tryCompute" was created later to return the "fallback" boolean. This + // could not be done directly in "compute" because the VTS test cases (which + // test burst using "compute") had already been locked down and could not be + // changed. + /** + * Execute a request on a model. + * + * @param request Arguments to be executed on a model. + * @param measure Whether to collect timing measurements, either YES or NO + * @param memoryIds Identifiers corresponding to each memory object in the + * request's pools. + * @return A tuple of: + * - status of the execution + * - dynamic output shapes from the execution + * - any execution time measurements of the execution + * - whether or not a failed burst execution should be re-run using a + * different path (e.g., IPreparedModel::executeSynchronously) + */ + std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> tryCompute( + const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds); + /** * Propagate a user's freeing of memory to the service. 
* @@ -286,10 +323,11 @@ class ExecutionBurstController { private: std::mutex mMutex; - const std::unique_ptr<RequestChannelSender> mRequestChannelSender; - const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver; + const std::shared_ptr<RequestChannelSender> mRequestChannelSender; + const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver; const sp<IBurstContext> mBurstContext; const sp<ExecutionBurstCallback> mMemoryCache; + const sp<hardware::hidl_death_recipient> mDeathHandler; }; } // namespace android::nn diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp index 8866a5a05..35b6788b4 100644 --- a/nn/runtime/ExecutionBuilder.cpp +++ b/nn/runtime/ExecutionBuilder.cpp @@ -898,7 +898,10 @@ int StepExecutor::startComputeOnDevice( // in the design document. sp<ExecutionCallback> executionCallback = new ExecutionCallback(); - if (burstController != nullptr) { + // compute using burst if present + const bool burstCompute = (burstController != nullptr); + bool burstFallback = false; + if (burstCompute) { std::vector<intptr_t> memoryIds; memoryIds.reserve(mMemories.size()); for (const Memory* memory : mMemories) { @@ -906,34 +909,46 @@ int StepExecutor::startComputeOnDevice( memoryIds.push_back(memory->getKey()); } - VLOG(EXECUTION) << "Before ExecutionBurstController->compute() " + VLOG(EXECUTION) << "Before ExecutionBurstController->tryCompute() " << SHOW_IF_DEBUG(toString(request)); - auto burstExecuteResult = - burstController->compute(request, measureTiming(mExecutionBuilder), memoryIds); - executionCallback->notify(std::get<0>(burstExecuteResult), std::get<1>(burstExecuteResult), - std::get<2>(burstExecuteResult)); - } else if (DeviceManager::get()->syncExecHal()) { - VLOG(EXECUTION) << "Before mPreparedModel->executeSynchronously() " - << SHOW_IF_DEBUG(toString(request)); - auto syncExecuteResult = - mPreparedModel->executeSynchronously(request, measureTiming(mExecutionBuilder)); - 
executionCallback->notify(std::get<0>(syncExecuteResult), std::get<1>(syncExecuteResult), - std::get<2>(syncExecuteResult)); - } else { - VLOG(EXECUTION) << "Before mPreparedModel->execute() " << SHOW_IF_DEBUG(toString(request)); - // Execute. - // TODO: What happens to the Callback if the service dies abnormally - // -- won't that keep the Callback live forever, because the service - // never has the opportunity to bump the reference count down? Or - // maybe the HIDL infrastructure handles this magically? At worst, - // it seems like this is a small memory leak, if the Callback stays - // alive forever. - Return<ErrorStatus> executeStatus = mPreparedModel->execute( - request, measureTiming(mExecutionBuilder), executionCallback); - if (!executeStatus.isOk() || executeStatus != ErrorStatus::NONE) { - VLOG(EXECUTION) << "**Execute launch failed**"; - return executeStatus.isOk() ? convertErrorStatusToResultCode(executeStatus) - : ANEURALNETWORKS_OP_FAILED; + auto [status, outputShapes, timing, fallback] = + burstController->tryCompute(request, measureTiming(mExecutionBuilder), memoryIds); + + burstFallback = fallback; + if (!fallback) { + executionCallback->notify(status, outputShapes, timing); + } + } + + // compute from IPreparedModel if either: + // (1) burst was not supplied, or + // (2) the burst execution failed and requested a fallback execution + if (!burstCompute || burstFallback) { + if (DeviceManager::get()->syncExecHal()) { + VLOG(EXECUTION) << "Before mPreparedModel->executeSynchronously() " + << SHOW_IF_DEBUG(toString(request)); + auto syncExecuteResult = + mPreparedModel->executeSynchronously(request, measureTiming(mExecutionBuilder)); + executionCallback->notify(std::get<0>(syncExecuteResult), + std::get<1>(syncExecuteResult), + std::get<2>(syncExecuteResult)); + } else { + VLOG(EXECUTION) << "Before mPreparedModel->execute() " + << SHOW_IF_DEBUG(toString(request)); + // Execute. 
+ // TODO: What happens to the Callback if the service dies abnormally + // -- won't that keep the Callback live forever, because the service + // never has the opportunity to bump the reference count down? Or + // maybe the HIDL infrastructure handles this magically? At worst, + // it seems like this is a small memory leak, if the Callback stays + // alive forever. + Return<ErrorStatus> executeStatus = mPreparedModel->execute( + request, measureTiming(mExecutionBuilder), executionCallback); + if (!executeStatus.isOk() || executeStatus != ErrorStatus::NONE) { + VLOG(EXECUTION) << "**Execute launch failed**"; + return executeStatus.isOk() ? convertErrorStatusToResultCode(executeStatus) + : ANEURALNETWORKS_OP_FAILED; + } } } |