diff options
Diffstat (limited to 'nn/common/ExecutionBurstServer.cpp')
-rw-r--r-- | nn/common/ExecutionBurstServer.cpp | 137 |
1 files changed, 94 insertions, 43 deletions
diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp index 74bc34058..ec935dad6 100644 --- a/nn/common/ExecutionBurstServer.cpp +++ b/nn/common/ExecutionBurstServer.cpp @@ -20,9 +20,14 @@ #include <android-base/logging.h> +#include <algorithm> #include <cstring> #include <limits> #include <map> +#include <memory> +#include <tuple> +#include <utility> +#include <vector> #include "Tracing.h" @@ -31,6 +36,8 @@ namespace { using namespace hal; +using hardware::MQDescriptorSync; + constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(), std::numeric_limits<uint64_t>::max()}; @@ -298,20 +305,27 @@ std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserial // RequestChannelReceiver methods std::unique_ptr<RequestChannelReceiver> RequestChannelReceiver::create( - const FmqRequestDescriptor& requestChannel) { + const FmqRequestDescriptor& requestChannel, std::chrono::microseconds pollingTimeWindow) { std::unique_ptr<FmqRequestChannel> fmqRequestChannel = std::make_unique<FmqRequestChannel>(requestChannel); + if (!fmqRequestChannel->isValid()) { LOG(ERROR) << "Unable to create RequestChannelReceiver"; return nullptr; } - const bool blocking = fmqRequestChannel->getEventFlagWord() != nullptr; - return std::make_unique<RequestChannelReceiver>(std::move(fmqRequestChannel), blocking); + if (fmqRequestChannel->getEventFlagWord() == nullptr) { + LOG(ERROR) + << "RequestChannelReceiver::create was passed an MQDescriptor without an EventFlag"; + return nullptr; + } + + return std::make_unique<RequestChannelReceiver>(std::move(fmqRequestChannel), + pollingTimeWindow); } RequestChannelReceiver::RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, - bool blocking) - : mFmqRequestChannel(std::move(fmqRequestChannel)), mBlocking(blocking) {} + std::chrono::microseconds pollingTimeWindow) + : mFmqRequestChannel(std::move(fmqRequestChannel)), kPollingTimeWindow(pollingTimeWindow) {} std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> RequestChannelReceiver::getBlocking() { @@ -328,17 +342,15 @@ void RequestChannelReceiver::invalidate() { // force unblock // ExecutionBurstServer is by default waiting on a request packet. If the - // client process destroys its burst object, the server will still be - // waiting on the futex (assuming mBlocking is true). This force unblock - // wakes up any thread waiting on the futex. - if (mBlocking) { - // TODO: look for a different/better way to signal/notify the futex to - // wake up any thread waiting on it - FmqRequestDatum datum; - datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0, - /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0}); - mFmqRequestChannel->writeBlocking(&datum, 1); - } + // client process destroys its burst object, the server may still be waiting + // on the futex. This force unblock wakes up any thread waiting on the + // futex. + // TODO: look for a different/better way to signal/notify the futex to wake + // up any thread waiting on it + FmqRequestDatum datum; + datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0, + /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0}); + mFmqRequestChannel->writeBlocking(&datum, 1); } std::optional<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlocking() { @@ -348,17 +360,53 @@ std::optional<std::vector<FmqRequestDatum>> RequestChannelReceiver::getPacketBlo return std::nullopt; } - // wait for request packet and read first element of request packet - FmqRequestDatum datum; - bool success = false; - if (mBlocking) { - success = mFmqRequestChannel->readBlocking(&datum, 1); - } else { - while ((success = !mTeardown.load(std::memory_order_relaxed)) && - !mFmqRequestChannel->read(&datum, 1)) { + // First spend time polling if results are available in FMQ instead of + // waiting on the futex. Polling is more responsive (yielding lower + // latencies), but can take up more power, so only poll for a limited period + // of time. + + auto& getCurrentTime = std::chrono::high_resolution_clock::now; + const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow; + + while (getCurrentTime() < timeToStopPolling) { + // if class is being torn down, immediately return + if (mTeardown.load(std::memory_order_relaxed)) { + return std::nullopt; + } + + // Check if data is available. If it is, immediately retrieve it and + // return. + const size_t available = mFmqRequestChannel->availableToRead(); + if (available > 0) { + // This is the first point when we know an execution is occurring, + // so begin to collect systraces. Note that a similar systrace does + // not exist at the corresponding point in + // ResultChannelReceiver::getPacketBlocking because the execution is + // already in flight. + NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, + "ExecutionBurstServer getting packet"); + std::vector<FmqRequestDatum> packet(available); + const bool success = mFmqRequestChannel->read(packet.data(), available); + if (!success) { + LOG(ERROR) << "Error receiving packet"; + return std::nullopt; + } + return std::make_optional(std::move(packet)); } } + // If we get to this point, we either stopped polling because it was taking + // too long or polling was not allowed. Instead, perform a blocking call + // which uses a futex to save power. + + // wait for request packet and read first element of request packet + FmqRequestDatum datum; + bool success = mFmqRequestChannel->readBlocking(&datum, 1); + + // This is the first point when we know an execution is occurring, so begin + // to collect systraces. Note that a similar systrace does not exist at the + // corresponding point in ResultChannelReceiver::getPacketBlocking because + // the execution is already in flight. NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet"); // retrieve remaining elements @@ -393,17 +441,21 @@ std::unique_ptr<ResultChannelSender> ResultChannelSender::create( const FmqResultDescriptor& resultChannel) { std::unique_ptr<FmqResultChannel> fmqResultChannel = std::make_unique<FmqResultChannel>(resultChannel); + if (!fmqResultChannel->isValid()) { LOG(ERROR) << "Unable to create RequestChannelSender"; return nullptr; } - const bool blocking = fmqResultChannel->getEventFlagWord() != nullptr; - return std::make_unique<ResultChannelSender>(std::move(fmqResultChannel), blocking); + if (fmqResultChannel->getEventFlagWord() == nullptr) { + LOG(ERROR) << "ResultChannelSender::create was passed an MQDescriptor without an EventFlag"; + return nullptr; + } + + return std::make_unique<ResultChannelSender>(std::move(fmqResultChannel)); } -ResultChannelSender::ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel, - bool blocking) - : mFmqResultChannel(std::move(fmqResultChannel)), mBlocking(blocking) {} +ResultChannelSender::ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel) + : mFmqResultChannel(std::move(fmqResultChannel)) {} bool ResultChannelSender::send(ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing) { @@ -417,18 +469,15 @@ bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) << "ResultChannelSender::sendPacket -- packet size exceeds size available in FMQ"; const std::vector<FmqResultDatum> errorPacket = serialize(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming); - if (mBlocking) { - return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size()); - } else { - return mFmqResultChannel->write(errorPacket.data(), errorPacket.size()); - } - } - if (mBlocking) { - return mFmqResultChannel->writeBlocking(packet.data(), packet.size()); - } else { - return mFmqResultChannel->write(packet.data(), packet.size()); + // Always send the packet with "blocking" because this signals the futex + // and unblocks the consumer if it is waiting on the futex. + return mFmqResultChannel->writeBlocking(errorPacket.data(), errorPacket.size()); } + + // Always send the packet with "blocking" because this signals the futex and + // unblocks the consumer if it is waiting on the futex. + return mFmqResultChannel->writeBlocking(packet.data(), packet.size()); } // ExecutionBurstServer methods @@ -436,7 +485,8 @@ bool ResultChannelSender::sendPacket(const std::vector<FmqResultDatum>& packet) sp<ExecutionBurstServer> ExecutionBurstServer::create( const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel, const MQDescriptorSync<FmqResultDatum>& resultChannel, - std::shared_ptr<IBurstExecutorWithCache> executorWithCache) { + std::shared_ptr<IBurstExecutorWithCache> executorWithCache, + std::chrono::microseconds pollingTimeWindow) { // check inputs if (callback == nullptr || executorWithCache == nullptr) { LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr"; @@ -445,7 +495,7 @@ sp<ExecutionBurstServer> ExecutionBurstServer::create( // create FMQ objects std::unique_ptr<RequestChannelReceiver> requestChannelReceiver = - RequestChannelReceiver::create(requestChannel); + RequestChannelReceiver::create(requestChannel, pollingTimeWindow); std::unique_ptr<ResultChannelSender> resultChannelSender = ResultChannelSender::create(resultChannel); @@ -462,7 +512,8 @@ sp<ExecutionBurstServer> ExecutionBurstServer::create( sp<ExecutionBurstServer> ExecutionBurstServer::create( const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel, - const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) { + const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel, + std::chrono::microseconds pollingTimeWindow) { // check relevant input if (preparedModel == nullptr) { LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr"; @@ -475,7 +526,7 @@ sp<ExecutionBurstServer> ExecutionBurstServer::create( // make and return context return ExecutionBurstServer::create(callback, requestChannel, resultChannel, - preparedModelAdapter); + preparedModelAdapter, pollingTimeWindow); } ExecutionBurstServer::ExecutionBurstServer( |