diff options
author | android-build-team Robot <android-build-team-robot@google.com> | 2019-05-16 03:14:26 +0000 |
---|---|---|
committer | android-build-team Robot <android-build-team-robot@google.com> | 2019-05-16 03:14:26 +0000 |
commit | 2af3599f21dda8dcd4aa1743bf95e17f708c2b67 (patch) | |
tree | ea9e0e63cadd83f9677c64801622c998f3e2538e | |
parent | c26781743003e5b928cf6b6b57217ec1851bd0e2 (diff) | |
parent | 4363637cbd8ca9b782a9ab5e2bfdc95404649910 (diff) | |
download | ml-2af3599f21dda8dcd4aa1743bf95e17f708c2b67.tar.gz |
Snap for 5571215 from 4363637cbd8ca9b782a9ab5e2bfdc95404649910 to qt-release
Change-Id: I33bd24dfa13da23692ccd985e2ee2fd990953c88
-rw-r--r-- | nn/common/ExecutionBurstController.cpp | 126 | ||||
-rw-r--r-- | nn/common/Utils.cpp | 34 | ||||
-rw-r--r-- | nn/common/ValidateHal.cpp | 93 | ||||
-rw-r--r-- | nn/common/include/ExecutionBurstController.h | 54 | ||||
-rw-r--r-- | nn/common/include/Utils.h | 2 | ||||
-rw-r--r-- | nn/common/include/ValidateHal.h | 13 | ||||
-rw-r--r-- | nn/common/operations/Activation.cpp | 43 | ||||
-rw-r--r-- | nn/common/operations/RoiPooling.cpp | 3 | ||||
-rw-r--r-- | nn/runtime/ExecutionBuilder.cpp | 71 | ||||
-rw-r--r-- | nn/runtime/test/TestCompliance.cpp | 88 | ||||
-rw-r--r-- | nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp | 4 | ||||
-rw-r--r-- | nn/runtime/test/generated/models/logistic_v1_2.model.cpp | 4 | ||||
-rw-r--r-- | nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp | 4 | ||||
-rw-r--r-- | nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py | 2 |
14 files changed, 407 insertions, 134 deletions
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp index be8be7105..55ef9ad29 100644 --- a/nn/common/ExecutionBurstController.cpp +++ b/nn/common/ExecutionBurstController.cpp @@ -34,6 +34,23 @@ using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>; constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(), std::numeric_limits<uint64_t>::max()}; +class BurstContextDeathHandler : public hardware::hidl_death_recipient { + public: + using Callback = std::function<void()>; + + BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) { + CHECK(onDeathCallback != nullptr); + } + + void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override { + LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!"; + mOnDeathCallback(); + } + + private: + const Callback mOnDeathCallback; +}; + } // anonymous namespace // serialize a request into a packet @@ -229,13 +246,13 @@ ResultChannelReceiver::getBlocking() { } void ResultChannelReceiver::invalidate() { - mTeardown = true; + mValid = false; // force unblock - // ExecutionBurstServer is by default waiting on a request packet. If the - // client process destroys its burst object, the server will still be - // waiting on the futex (assuming mBlocking is true). This force unblock - // wakes up any thread waiting on the futex. + // ExecutionBurstController waits on a result packet after sending a + // request. If the driver containing ExecutionBurstServer crashes, the + // controller will still be waiting on the futex (assuming mBlocking is + // true). This force unblock wakes up any thread waiting on the futex. if (mBlocking) { // TODO: look for a different/better way to signal/notify the futex to // wake up any thread waiting on it @@ -249,7 +266,7 @@ void ResultChannelReceiver::invalidate() { std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() { using discriminator = FmqResultDatum::hidl_discriminator; - if (mTeardown) { + if (!mValid) { return std::nullopt; } @@ -259,7 +276,7 @@ std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlock if (mBlocking) { success = mFmqResultChannel->readBlocking(&datum, 1); } else { - while ((success = !mTeardown.load(std::memory_order_relaxed)) && + while ((success = mValid.load(std::memory_order_relaxed)) && !mFmqResultChannel->read(&datum, 1)) { } } @@ -276,8 +293,7 @@ std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlock std::memcpy(&packet.front(), &datum, sizeof(datum)); success &= mFmqResultChannel->read(packet.data() + 1, count); - // terminate loop - if (mTeardown) { + if (!mValid) { return std::nullopt; } @@ -315,6 +331,10 @@ bool RequestChannelSender::send(const Request& request, MeasureTiming measure, } bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) { + if (!mValid) { + return false; + } + if (packet.size() > mFmqRequestChannel->availableToWrite()) { LOG(ERROR) << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ"; @@ -328,6 +348,10 @@ bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet } } +void RequestChannelSender::invalidate() { + mValid = false; +} + Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories( const hidl_vec<int32_t>& slots, getMemories_cb cb) { std::lock_guard<std::mutex> guard(mMutex); @@ -420,19 +444,20 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create( // create callback object sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback(); - if (callback == nullptr) { - LOG(ERROR) << "ExecutionBurstController::create failed to create callback"; - return nullptr; - } // create FMQ objects - auto [fmqRequestChannel, fmqRequestDescriptor] = + auto [requestChannelSenderTemp, requestChannelDescriptor] = RequestChannelSender::create(kExecutionBurstChannelLength, blocking); - auto [fmqResultChannel, fmqResultDescriptor] = + auto [resultChannelReceiverTemp, resultChannelDescriptor] = ResultChannelReceiver::create(kExecutionBurstChannelLength, blocking); + std::shared_ptr<RequestChannelSender> requestChannelSender = + std::move(requestChannelSenderTemp); + std::shared_ptr<ResultChannelReceiver> resultChannelReceiver = + std::move(resultChannelReceiverTemp); // check FMQ objects - if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestDescriptor || !fmqResultDescriptor) { + if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor || + !resultChannelDescriptor) { LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue"; return nullptr; } @@ -441,7 +466,7 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create( ErrorStatus errorStatus; sp<IBurstContext> burstContext; const Return<void> ret = preparedModel->configureExecutionBurst( - callback, *fmqRequestDescriptor, *fmqResultDescriptor, + callback, *requestChannelDescriptor, *resultChannelDescriptor, [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) { errorStatus = status; burstContext = context; @@ -463,22 +488,61 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create( return nullptr; } + // create death handler object + BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender, + resultChannelReceiver] { + requestChannelSender->invalidate(); + resultChannelReceiver->invalidate(); + }; + const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback); + + // linkToDeath registers a callback that will be invoked on service death to + // proactively handle service crashes. If the linkToDeath call fails, + // asynchronous calls are susceptible to hangs if the service crashes before + // providing the response. + const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0); + if (!deathHandlerRet.isOk() || deathHandlerRet != true) { + LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient " + "for the IBurstContext object."; + return nullptr; + } + // make and return controller - return std::make_unique<ExecutionBurstController>( - std::move(fmqRequestChannel), std::move(fmqResultChannel), burstContext, callback); + return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver, + burstContext, callback, deathHandler); } ExecutionBurstController::ExecutionBurstController( - std::unique_ptr<RequestChannelSender> requestChannelSender, - std::unique_ptr<ResultChannelReceiver> resultChannelReceiver, - const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback) - : mRequestChannelSender(std::move(requestChannelSender)), - mResultChannelReceiver(std::move(resultChannelReceiver)), + const std::shared_ptr<RequestChannelSender>& requestChannelSender, + const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver, + const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback, + const sp<hardware::hidl_death_recipient>& deathHandler) + : mRequestChannelSender(requestChannelSender), + mResultChannelReceiver(resultChannelReceiver), mBurstContext(burstContext), - mMemoryCache(callback) {} + mMemoryCache(callback), + mDeathHandler(deathHandler) {} + +ExecutionBurstController::~ExecutionBurstController() { + // It is safe to ignore any errors resulting from this unlinkToDeath call + // because the ExecutionBurstController object is already being destroyed + // and its underlying IBurstContext object is no longer being used by the NN + // runtime. + if (mDeathHandler) { + mBurstContext->unlinkToDeath(mDeathHandler).isOk(); + } +} std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute( const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) { + auto [status, outputShapes, timing, fallback] = tryCompute(request, measure, memoryIds); + (void)fallback; // ignore fallback field + return {status, std::move(outputShapes), timing}; +} + +std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> +ExecutionBurstController::tryCompute(const Request& request, MeasureTiming measure, + const std::vector<intptr_t>& memoryIds) { NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute"); std::lock_guard<std::mutex> guard(mMutex); @@ -488,16 +552,22 @@ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstControll const bool success = mRequestChannelSender->send(request, measure, slots); if (!success) { LOG(ERROR) << "Error sending FMQ packet"; - return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming}; + // only use fallback execution path if the packet could not be sent + return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/true}; } // get result packet const auto result = mResultChannelReceiver->getBlocking(); if (!result) { LOG(ERROR) << "Error retrieving FMQ packet"; - return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming}; + // only use fallback execution path if the packet could not be sent + return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/false}; } - return *result; + + // unpack results and return (only use fallback execution path if the + // packet could not be sent) + auto [status, outputShapes, timing] = std::move(*result); + return {status, std::move(outputShapes), timing, /*fallback=*/false}; } void ExecutionBurstController::freeMemory(intptr_t key) { diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index a1b3b6a7e..731127a26 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -2076,6 +2076,12 @@ static hidl_vec<V1_1::Operation> convertToV1_1(const hidl_vec<V1_0::Operation>& return result; } +bool compliantWithV1_0(const V1_2::Operand& operand) { + return validOperandType(static_cast<V1_0::OperandType>(operand.type)) && + (nonExtensionOperandTypeIsScalar(static_cast<int>(operand.type)) || + operand.dimensions.size() != 0); +} + V1_0::Model convertToV1_0(const V1_0::Model& model) { return model; } @@ -2119,7 +2125,33 @@ void logModelToInfo(const V1_2::Model& model) { static bool compliantWith(HalVersion version, const V1_2::Model& model, std::set<uint32_t>* noncompliantOperations) { - auto localValidateOperation = [&model, version](const V1_2::Operation& op) { + if (version >= HalVersion::V1_2) return true; + + // A boolean vector indicating whether each pool is compliant with the target HAL version. + std::vector<bool> isPoolCompliant(model.pools.size(), false); + std::transform(model.pools.begin(), model.pools.end(), isPoolCompliant.begin(), + [version](const hidl_memory& pool) { return validatePool(pool, version); }); + + // A boolean vector indicating whether each operand is compliant with the target HAL version. + std::vector<bool> isOperandCompliant(model.operands.size(), false); + std::transform(model.operands.begin(), model.operands.end(), isOperandCompliant.begin(), + [&isPoolCompliant](const V1_2::Operand& op) { + // There is no V1_1::Operand -- both V1_0::Model and V1_1::Model use + // V1_0::Operand. + return compliantWithV1_0(op) && + !(op.lifetime == OperandLifeTime::CONSTANT_REFERENCE && + !isPoolCompliant[op.location.poolIndex]); + }); + + auto allOperandsCompliant = [&isOperandCompliant](const hidl_vec<uint32_t>& indices) { + return std::all_of( + indices.begin(), indices.end(), + [&isOperandCompliant](const uint32_t ind) { return isOperandCompliant[ind]; }); + }; + + auto localValidateOperation = [&model, version, + &allOperandsCompliant](const V1_2::Operation& op) { + if (!allOperandsCompliant(op.inputs) || !allOperandsCompliant(op.outputs)) return false; int error = validateOperation( static_cast<int32_t>(op.type), op.inputs.size(), op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(), diff --git a/nn/common/ValidateHal.cpp b/nn/common/ValidateHal.cpp index 1015d922d..421730a1d 100644 --- a/nn/common/ValidateHal.cpp +++ b/nn/common/ValidateHal.cpp @@ -27,6 +27,21 @@ namespace android { namespace nn { +template <class T_Model> +struct ModelToHalVersion; +template <> +struct ModelToHalVersion<V1_0::Model> { + static constexpr HalVersion version = HalVersion::V1_0; +}; +template <> +struct ModelToHalVersion<V1_1::Model> { + static constexpr HalVersion version = HalVersion::V1_1; +}; +template <> +struct ModelToHalVersion<V1_2::Model> { + static constexpr HalVersion version = HalVersion::V1_2; +}; + class MemoryAccessVerifier { public: MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools) @@ -418,22 +433,26 @@ static bool validateOperations(const hidl_vec<VersionedOperation>& operations, return true; } -static bool validatePools(const hidl_vec<hidl_memory>& pools) { - for (const hidl_memory& memory : pools) { - const auto& name = memory.name(); - if (name != "ashmem" && name != "mmap_fd" && name != "hardware_buffer_blob" && - name != "hardware_buffer") { - LOG(ERROR) << "Unsupported memory type " << name; - return false; - } - if (memory.handle() == nullptr) { - LOG(ERROR) << "Memory of type " << name << " is null"; - return false; - } +bool validatePool(const hidl_memory& pool, HalVersion ver) { + const auto& name = pool.name(); + if (name != "ashmem" && name != "mmap_fd" && + ((ver < HalVersion::V1_2) || + (name != "hardware_buffer_blob" && name != "hardware_buffer"))) { + LOG(ERROR) << "Unsupported memory type " << name; + return false; + } + if (pool.handle() == nullptr) { + LOG(ERROR) << "Memory of type " << name << " is null"; + return false; } return true; } +static bool validatePools(const hidl_vec<hidl_memory>& pools, HalVersion ver) { + return std::all_of(pools.begin(), pools.end(), + [ver](const hidl_memory& pool) { return validatePool(pool, ver); }); +} + static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes, const hidl_vec<Operand>& operands, OperandLifeTime lifetime) { const size_t operandCount = operands.size(); @@ -460,10 +479,10 @@ static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes, return true; } -template <typename VersionedModel> -static bool validateModelVersioned(const VersionedModel& model, bool allowUnspecifiedRank) { - NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, - "validateModelVersioned"); +template <class T_Model> +bool validateModel(const T_Model& model) { + NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel"); + HalVersion version = ModelToHalVersion<T_Model>::version; if (model.operations.size() == 0 || model.operands.size() == 0) { LOG(ERROR) << "Invalid empty model."; return false; @@ -472,26 +491,18 @@ static bool validateModelVersioned(const VersionedModel& model, bool allowUnspec // validations we can use operands upcasted to the latest version. const hidl_vec<Operand> latestVersionOperands = convertToV1_2(model.operands); return (validateOperands(model.operands, model.operandValues, model.pools, - allowUnspecifiedRank) && + /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) && validateOperations(model.operations, latestVersionOperands) && validateModelInputOutputs(model.inputIndexes, latestVersionOperands, OperandLifeTime::MODEL_INPUT) && validateModelInputOutputs(model.outputIndexes, latestVersionOperands, OperandLifeTime::MODEL_OUTPUT) && - validatePools(model.pools)); -} - -bool validateModel(const V1_0::Model& model) { - return validateModelVersioned(model, /*allowUnspecifiedRank=*/false); + validatePools(model.pools, version)); } -bool validateModel(const V1_1::Model& model) { - return validateModelVersioned(model, /*allowUnspecifiedRank=*/false); -} - -bool validateModel(const V1_2::Model& model) { - return validateModelVersioned(model, /*allowUnspecifiedRank=*/true); -} +template bool validateModel<V1_0::Model>(const V1_0::Model& model); +template bool validateModel<V1_1::Model>(const V1_1::Model& model); +template bool validateModel<V1_2::Model>(const V1_2::Model& model); // Validates the arguments of a request. type is either "input" or "output" and is used // for printing error messages. The operandIndexes is the appropriate array of input @@ -572,29 +583,21 @@ static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArg return true; } -template <typename VersionedModel> -static bool validateRequestVersioned(const Request& request, const VersionedModel& model, - bool allowDynamicOutputShape) { +template <class T_Model> +bool validateRequest(const Request& request, const T_Model& model) { + HalVersion version = ModelToHalVersion<T_Model>::version; return (validateRequestArguments(request.inputs, model.inputIndexes, convertToV1_2(model.operands), request.pools, /*allowUnspecified=*/false, "input") && validateRequestArguments(request.outputs, model.outputIndexes, convertToV1_2(model.operands), request.pools, - /*allowUnspecified=*/allowDynamicOutputShape, "output") && - validatePools(request.pools)); + /*allowUnspecified=*/version >= HalVersion::V1_2, "output") && + validatePools(request.pools, version)); } -bool validateRequest(const Request& request, const V1_0::Model& model) { - return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/false); -} - -bool validateRequest(const Request& request, const V1_1::Model& model) { - return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/false); -} - -bool validateRequest(const Request& request, const V1_2::Model& model) { - return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/true); -} +template bool validateRequest<V1_0::Model>(const Request& request, const V1_0::Model& model); +template bool validateRequest<V1_1::Model>(const Request& request, const V1_1::Model& model); +template bool validateRequest<V1_2::Model>(const Request& request, const V1_2::Model& model); bool validateExecutionPreference(ExecutionPreference preference) { return preference == ExecutionPreference::LOW_POWER || diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h index 3397d9673..030ea4d2d 100644 --- a/nn/common/include/ExecutionBurstController.h +++ b/nn/common/include/ExecutionBurstController.h @@ -117,7 +117,7 @@ class ResultChannelReceiver { private: const std::unique_ptr<FmqResultChannel> mFmqResultChannel; - std::atomic<bool> mTeardown{false}; + std::atomic<bool> mValid{true}; const bool mBlocking; }; @@ -157,6 +157,13 @@ class RequestChannelSender { */ bool send(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots); + /** + * Method to mark the channel as invalid, causing all future calls to + * RequestChannelSender::send to immediately return false without attempting + * to send a message across the FMQ. + */ + void invalidate(); + // prefer calling RequestChannelSender::send bool sendPacket(const std::vector<FmqRequestDatum>& packet); @@ -164,6 +171,7 @@ class RequestChannelSender { private: const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel; + std::atomic<bool> mValid{true}; const bool mBlocking; }; @@ -259,10 +267,15 @@ class ExecutionBurstController { static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel, bool blocking); - ExecutionBurstController(std::unique_ptr<RequestChannelSender> requestChannelSender, - std::unique_ptr<ResultChannelReceiver> resultChannelReceiver, + // prefer calling ExecutionBurstController::create + ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender, + const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver, const sp<IBurstContext>& burstContext, - const sp<ExecutionBurstCallback>& callback); + const sp<ExecutionBurstCallback>& callback, + const sp<hardware::hidl_death_recipient>& deathHandler = nullptr); + + // explicit destructor to unregister the death recipient + ~ExecutionBurstController(); /** * Execute a request on a model. @@ -271,12 +284,36 @@ class ExecutionBurstController { * @param measure Whether to collect timing measurements, either YES or NO * @param memoryIds Identifiers corresponding to each memory object in the * request's pools. - * @return status and output shape of the execution and any execution time - * measurements. + * @return A tuple of: + * - status of the execution + * - dynamic output shapes from the execution + * - any execution time measurements of the execution */ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute( const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds); + // TODO: combine "compute" and "tryCompute" back into a single function. + // "tryCompute" was created later to return the "fallback" boolean. This + // could not be done directly in "compute" because the VTS test cases (which + // test burst using "compute") had already been locked down and could not be + // changed. + /** + * Execute a request on a model. + * + * @param request Arguments to be executed on a model. + * @param measure Whether to collect timing measurements, either YES or NO + * @param memoryIds Identifiers corresponding to each memory object in the + * request's pools. + * @return A tuple of: + * - status of the execution + * - dynamic output shapes from the execution + * - any execution time measurements of the execution + * - whether or not a failed burst execution should be re-run using a + * different path (e.g., IPreparedModel::executeSynchronously) + */ + std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> tryCompute( + const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds); + /** * Propagate a user's freeing of memory to the service. * @@ -286,10 +323,11 @@ class ExecutionBurstController { private: std::mutex mMutex; - const std::unique_ptr<RequestChannelSender> mRequestChannelSender; - const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver; + const std::shared_ptr<RequestChannelSender> mRequestChannelSender; + const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver; const sp<IBurstContext> mBurstContext; const sp<ExecutionBurstCallback> mMemoryCache; + const sp<hardware::hidl_death_recipient> mDeathHandler; }; } // namespace android::nn diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h index 64472f12a..bf6cffda4 100644 --- a/nn/common/include/Utils.h +++ b/nn/common/include/Utils.h @@ -315,6 +315,8 @@ bool compliantWithV1_2(const V1_0::Capabilities& capabilities); bool compliantWithV1_2(const V1_1::Capabilities& capabilities); bool compliantWithV1_2(const V1_2::Capabilities& capabilities); +bool compliantWithV1_0(const V1_2::Operand& operand); + // If noncompliantOperations != nullptr, then // precondition: noncompliantOperations->empty() // postcondition: *noncompliantOperations consists of the indices of the noncompliant diff --git a/nn/common/include/ValidateHal.h b/nn/common/include/ValidateHal.h index c953d8a10..4275a24a7 100644 --- a/nn/common/include/ValidateHal.h +++ b/nn/common/include/ValidateHal.h @@ -36,17 +36,15 @@ enum class HalVersion : int32_t { // IMPORTANT: This function cannot validate that OEM operation and operands // are correctly defined, as these are specific to each implementation. // Each driver should do their own validation of OEM types. -bool validateModel(const V1_0::Model& model); -bool validateModel(const V1_1::Model& model); -bool validateModel(const V1_2::Model& model); +template <class T_Model> +bool validateModel(const T_Model& model); // Verfies that the request for the given model is valid. // IMPORTANT: This function cannot validate that OEM operation and operands // are correctly defined, as these are specific to each implementation. // Each driver should do their own validation of OEM types. -bool validateRequest(const Request& request, const V1_0::Model& model); -bool validateRequest(const Request& request, const V1_1::Model& model); -bool validateRequest(const Request& request, const V1_2::Model& model); +template <class T_Model> +bool validateRequest(const Request& request, const T_Model& model); // Verfies that the execution preference is valid. bool validateExecutionPreference(ExecutionPreference preference); @@ -58,6 +56,9 @@ bool validOperationType(V1_2::OperationType operation); bool validOperandType(V1_0::OperandType operand); bool validOperandType(V1_2::OperandType operand); +// Verfies that the memory pool is valid in the specified HAL version. +bool validatePool(const hidl_memory& pool, HalVersion ver = HalVersion::LATEST); + } // namespace nn } // namespace android diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp index 491226e60..f85f6b4bf 100644 --- a/nn/common/operations/Activation.cpp +++ b/nn/common/operations/Activation.cpp @@ -226,11 +226,28 @@ bool validate(OperationType opType, const IOperationValidationContext* context) return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); } -bool prepare(IOperationExecutionContext* context) { +bool prepare(OperationType opType, IOperationExecutionContext* context) { Shape input = context->getInputShape(kInputTensor); NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); - Shape output = context->getOutputShape(kOutputTensor); - output.dimensions = input.dimensions; + Shape output = input; + if (input.type == OperandType::TENSOR_QUANT8_ASYMM) { + switch (opType) { + case OperationType::RELU: + case OperationType::RELU1: + case OperationType::RELU6: + break; + case OperationType::LOGISTIC: + output.scale = 1.f / 256; + output.offset = 0; + break; + case OperationType::TANH: + output.scale = 1.f / 128; + output.offset = 128; + break; + default: + NN_RET_CHECK_FAIL() << "Unsupported operation type"; + } + } return context->setOutputShape(kOutputTensor, output); } @@ -326,7 +343,7 @@ bool executeLogistic(IOperationExecutionContext* context) { context->getOutputBuffer<uint8_t>(kOutputTensor), context->getOutputShape(kOutputTensor)); default: - NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH"; + NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC"; } } @@ -350,7 +367,7 @@ bool executeTanh(IOperationExecutionContext* context) { context->getOutputBuffer<uint8_t>(kOutputTensor), context->getOutputShape(kOutputTensor)); default: - NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC"; + NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH"; } } @@ -358,17 +375,21 @@ bool executeTanh(IOperationExecutionContext* context) { using std::placeholders::_1; NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1), - activation::prepare, activation::executeRelu, .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::RELU, _1), + activation::executeRelu, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1), - activation::prepare, activation::executeRelu1, .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::RELU1, _1), + activation::executeRelu1, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1), - activation::prepare, activation::executeRelu6, .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::RELU6, _1), + activation::executeRelu6, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC", std::bind(activation::validate, OperationType::LOGISTIC, _1), - activation::prepare, activation::executeLogistic, - .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::LOGISTIC, _1), + activation::executeLogistic, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1), - activation::prepare, activation::executeTanh, .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::TANH, _1), + activation::executeTanh, .allowZeroSizedInput = true); } // namespace nn } // namespace android diff --git a/nn/common/operations/RoiPooling.cpp b/nn/common/operations/RoiPooling.cpp index 7edbd8123..37914fdc9 100644 --- a/nn/common/operations/RoiPooling.cpp +++ b/nn/common/operations/RoiPooling.cpp @@ -235,8 +235,7 @@ bool prepare(IOperationExecutionContext* context) { NN_RET_CHECK_EQ(roiShape.offset, 0); } - Shape output = context->getOutputShape(kOutputTensor); - output.type = input.type; + Shape output = input; if (useNchw) { output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight), static_cast<uint32_t>(outputWidth)}; diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp index 8866a5a05..35b6788b4 100644 --- a/nn/runtime/ExecutionBuilder.cpp +++ b/nn/runtime/ExecutionBuilder.cpp @@ -898,7 +898,10 @@ int StepExecutor::startComputeOnDevice( // in the design document. sp<ExecutionCallback> executionCallback = new ExecutionCallback(); - if (burstController != nullptr) { + // compute using burst if present + const bool burstCompute = (burstController != nullptr); + bool burstFallback = false; + if (burstCompute) { std::vector<intptr_t> memoryIds; memoryIds.reserve(mMemories.size()); for (const Memory* memory : mMemories) { @@ -906,34 +909,46 @@ int StepExecutor::startComputeOnDevice( memoryIds.push_back(memory->getKey()); } - VLOG(EXECUTION) << "Before ExecutionBurstController->compute() " + VLOG(EXECUTION) << "Before ExecutionBurstController->tryCompute() " << SHOW_IF_DEBUG(toString(request)); - auto burstExecuteResult = - burstController->compute(request, measureTiming(mExecutionBuilder), memoryIds); - executionCallback->notify(std::get<0>(burstExecuteResult), std::get<1>(burstExecuteResult), - std::get<2>(burstExecuteResult)); - } else if (DeviceManager::get()->syncExecHal()) { - VLOG(EXECUTION) << "Before mPreparedModel->executeSynchronously() " - << SHOW_IF_DEBUG(toString(request)); - auto syncExecuteResult = - mPreparedModel->executeSynchronously(request, measureTiming(mExecutionBuilder)); - executionCallback->notify(std::get<0>(syncExecuteResult), std::get<1>(syncExecuteResult), - std::get<2>(syncExecuteResult)); - } else { - VLOG(EXECUTION) << "Before mPreparedModel->execute() " << SHOW_IF_DEBUG(toString(request)); - // Execute. - // TODO: What happens to the Callback if the service dies abnormally - // -- won't that keep the Callback live forever, because the service - // never has the opportunity to bump the reference count down? Or - // maybe the HIDL infrastructure handles this magically? At worst, - // it seems like this is a small memory leak, if the Callback stays - // alive forever. - Return<ErrorStatus> executeStatus = mPreparedModel->execute( - request, measureTiming(mExecutionBuilder), executionCallback); - if (!executeStatus.isOk() || executeStatus != ErrorStatus::NONE) { - VLOG(EXECUTION) << "**Execute launch failed**"; - return executeStatus.isOk() ? convertErrorStatusToResultCode(executeStatus) - : ANEURALNETWORKS_OP_FAILED; + auto [status, outputShapes, timing, fallback] = + burstController->tryCompute(request, measureTiming(mExecutionBuilder), memoryIds); + + burstFallback = fallback; + if (!fallback) { + executionCallback->notify(status, outputShapes, timing); + } + } + + // compute from IPreparedModel if either: + // (1) burst was not supplied, or + // (2) the burst execution failed and requested a fallback execution + if (!burstCompute || burstFallback) { + if (DeviceManager::get()->syncExecHal()) { + VLOG(EXECUTION) << "Before mPreparedModel->executeSynchronously() " + << SHOW_IF_DEBUG(toString(request)); + auto syncExecuteResult = + mPreparedModel->executeSynchronously(request, measureTiming(mExecutionBuilder)); + executionCallback->notify(std::get<0>(syncExecuteResult), + std::get<1>(syncExecuteResult), + std::get<2>(syncExecuteResult)); + } else { + VLOG(EXECUTION) << "Before mPreparedModel->execute() " + << SHOW_IF_DEBUG(toString(request)); + // Execute. + // TODO: What happens to the Callback if the service dies abnormally + // -- won't that keep the Callback live forever, because the service + // never has the opportunity to bump the reference count down? Or + // maybe the HIDL infrastructure handles this magically? At worst, + // it seems like this is a small memory leak, if the Callback stays + // alive forever. + Return<ErrorStatus> executeStatus = mPreparedModel->execute( + request, measureTiming(mExecutionBuilder), executionCallback); + if (!executeStatus.isOk() || executeStatus != ErrorStatus::NONE) { + VLOG(EXECUTION) << "**Execute launch failed**"; + return executeStatus.isOk() ? convertErrorStatusToResultCode(executeStatus) + : ANEURALNETWORKS_OP_FAILED; + } } } diff --git a/nn/runtime/test/TestCompliance.cpp b/nn/runtime/test/TestCompliance.cpp index 93918c803..52764154c 100644 --- a/nn/runtime/test/TestCompliance.cpp +++ b/nn/runtime/test/TestCompliance.cpp @@ -27,12 +27,15 @@ namespace compliance_test { using namespace ::android::nn; using HidlModel = V1_2::Model; using WrapperModel = test_wrapper::Model; +using WrapperOperandType = test_wrapper::OperandType; +using WrapperType = test_wrapper::Type; // Creates a HIDL model from a creator of the wrapper model. static HidlModel createHidlModel(std::function<void(WrapperModel*)> createModel) { HidlModel hidlModel; WrapperModel wrapperModel; createModel(&wrapperModel); + EXPECT_EQ(wrapperModel.finish(), test_wrapper::Result::NO_ERROR); ModelBuilder* modelBuilder = reinterpret_cast<ModelBuilder*>(wrapperModel.getHandle()); modelBuilder->setHidlModel(&hidlModel); return hidlModel; @@ -56,4 +59,89 @@ void ComplianceTest::testAvailableSinceV1_0(std::function<void(WrapperModel*)> c ASSERT_TRUE(compliantWithV1_0(model)); } +static const WrapperOperandType kTypeTensorFloat(WrapperType::TENSOR_FLOAT32, {1}); +static const WrapperOperandType kTypeTensorFloatRank0(WrapperType::TENSOR_FLOAT32, {}); +static const WrapperOperandType kTypeInt32(WrapperType::INT32, {}); + +TEST_F(ComplianceTest, Rank0TensorModelInput) { + int32_t act_init = 0; + // A simple ADD operation: op1 ADD op2 = op3, with op1 and op2 of rank 0. + testAvailableSinceV1_2([&act_init](WrapperModel* model) { + auto op1 = model->addOperand(&kTypeTensorFloatRank0); + auto op2 = model->addOperand(&kTypeTensorFloatRank0); + auto act = model->addOperand(&kTypeInt32); + auto op3 = model->addOperand(&kTypeTensorFloat); + model->setOperandValue(act, &act_init, sizeof(act_init)); + model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3}); + model->identifyInputsAndOutputs({op1, op2}, {op3}); + assert(model->isValid()); + }); +} + +TEST_F(ComplianceTest, Rank0TensorModelOutput) { + int32_t act_init = 0; + // A simple ADD operation: op1 ADD op2 = op3, with op3 of rank 0. + testAvailableSinceV1_2([&act_init](WrapperModel* model) { + auto op1 = model->addOperand(&kTypeTensorFloat); + auto op2 = model->addOperand(&kTypeTensorFloat); + auto act = model->addOperand(&kTypeInt32); + auto op3 = model->addOperand(&kTypeTensorFloatRank0); + model->setOperandValue(act, &act_init, sizeof(act_init)); + model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3}); + model->identifyInputsAndOutputs({op1, op2}, {op3}); + assert(model->isValid()); + }); +} + +TEST_F(ComplianceTest, Rank0TensorTemporaryVariable) { + int32_t act_init = 0; + // Two ADD operations: op1 ADD op2 = op3, op3 ADD op4 = op5, with op3 of rank 0. + testAvailableSinceV1_2([&act_init](WrapperModel* model) { + auto op1 = model->addOperand(&kTypeTensorFloat); + auto op2 = model->addOperand(&kTypeTensorFloat); + auto op3 = model->addOperand(&kTypeTensorFloatRank0); + auto op4 = model->addOperand(&kTypeTensorFloat); + auto op5 = model->addOperand(&kTypeTensorFloat); + auto act = model->addOperand(&kTypeInt32); + model->setOperandValue(act, &act_init, sizeof(act_init)); + model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3}); + model->addOperation(ANEURALNETWORKS_ADD, {op3, op4, act}, {op5}); + model->identifyInputsAndOutputs({op1, op2, op4}, {op5}); + assert(model->isValid()); + }); +} + +TEST_F(ComplianceTest, HardwareBuffer) { + const size_t memorySize = 20; + AHardwareBuffer_Desc desc{ + .width = memorySize, + .height = 1, + .layers = 1, + .format = AHARDWAREBUFFER_FORMAT_BLOB, + .usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN, + }; + + AHardwareBuffer* buffer = nullptr; + ASSERT_EQ(AHardwareBuffer_allocate(&desc, &buffer), 0); + test_wrapper::Memory memory(buffer); + ASSERT_TRUE(memory.isValid()); + + int32_t act_init = 0; + + // A simple ADD operation: op1 ADD op2 = op3, with op2 using a const hardware buffer. + testAvailableSinceV1_2([&memory, &act_init](WrapperModel* model) { + auto op1 = model->addOperand(&kTypeTensorFloat); + auto op2 = model->addOperand(&kTypeTensorFloat); + auto act = model->addOperand(&kTypeInt32); + auto op3 = model->addOperand(&kTypeTensorFloat); + model->setOperandValueFromMemory(op2, &memory, 0, sizeof(float)); + model->setOperandValue(act, &act_init, sizeof(act_init)); + model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3}); + model->identifyInputsAndOutputs({op1}, {op3}); + assert(model->isValid()); + }); + + AHardwareBuffer_release(buffer); +} + } // namespace compliance_test diff --git a/nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp b/nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp index afc19c3ee..101e2f707 100644 --- a/nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp +++ b/nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp @@ -57,6 +57,10 @@ static void roiConstructor(Type, uint32_t rank, RandomOperation* op) { } else { op->outputs[0]->dimensions = {outBatch, outHeight, outWidth, outDepth}; } + + if (op->opType == ANEURALNETWORKS_ROI_POOLING) { + setSameQuantization(op->outputs[0], op->inputs[0]); + } } template <typename T> diff --git a/nn/runtime/test/generated/models/logistic_v1_2.model.cpp b/nn/runtime/test/generated/models/logistic_v1_2.model.cpp index 3c5187861..170d881eb 100644 --- a/nn/runtime/test/generated/models/logistic_v1_2.model.cpp +++ b/nn/runtime/test/generated/models/logistic_v1_2.model.cpp @@ -252,7 +252,7 @@ void CreateModel_zero_sized_quant8(Model *model) { OperandType type10(Type::BOOL, {}); OperandType type14(Type::TENSOR_QUANT8_ASYMM, {0, 2, 2, 1}, 0.1f, 128); OperandType type15(Type::TENSOR_QUANT8_ASYMM, {1, 1, 1, 1}, 0.1f, 128); - OperandType type16(Type::TENSOR_QUANT8_ASYMM, {0, 2, 2, 1}, 0.00390625f, 128); + OperandType type16(Type::TENSOR_QUANT8_ASYMM, {0, 2, 2, 1}, 0.00390625f, 0); OperandType type17(Type::TENSOR_QUANT16_ASYMM, {1, 8}, 0.125f, 0); OperandType type18(Type::TENSOR_QUANT16_ASYMM, {0, 4}, 0.125f, 0); OperandType type19(Type::TENSOR_QUANT8_ASYMM, {1, 2}, 0.1f, 128); @@ -597,7 +597,7 @@ void CreateModel_zero_sized_dynamic_output_shape_quant8(Model *model) { OperandType type18(Type::TENSOR_QUANT16_ASYMM, {0, 4}, 0.125f, 0); OperandType type19(Type::TENSOR_QUANT8_ASYMM, {1, 2}, 0.1f, 128); OperandType type20(Type::TENSOR_QUANT8_ASYMM, {0}, 0.1f, 128); - OperandType type29(Type::TENSOR_QUANT8_ASYMM, {0, 0, 0, 0}, 0.00390625f, 128); + OperandType type29(Type::TENSOR_QUANT8_ASYMM, {0, 0, 0, 0}, 0.00390625f, 0); OperandType type5(Type::TENSOR_INT32, {0}); OperandType type7(Type::TENSOR_INT32, {1}); OperandType type8(Type::FLOAT32, {}); diff --git a/nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp b/nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp index cd0f47a49..fc9419281 100644 --- a/nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp +++ b/nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp @@ -915,7 +915,7 @@ Model createTestModel_zero_sized_quant8() { .dimensions = {0, 2, 2, 1}, .numberOfConsumers = 0, .scale = 0.00390625f, - .zeroPoint = 128, + .zeroPoint = 0, .lifetime = OperandLifeTime::MODEL_OUTPUT, .location = {.poolIndex = 0, .offset = 0, .length = 0}, } @@ -1924,7 +1924,7 @@ Model createTestModel_zero_sized_dynamic_output_shape_quant8() { .dimensions = {0, 0, 0, 0}, .numberOfConsumers = 0, .scale = 0.00390625f, - .zeroPoint = 128, + .zeroPoint = 0, .lifetime = OperandLifeTime::MODEL_OUTPUT, .location = {.poolIndex = 0, .offset = 0, .length = 0}, } diff --git a/nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py b/nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py index f8037b1c5..fe91a814d 100644 --- a/nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py +++ b/nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py @@ -82,7 +82,7 @@ quant8 = DataTypeConverter().Identify({ tmp1: ("TENSOR_QUANT16_ASYMM", 0.125, 0), i1: ("TENSOR_QUANT8_ASYMM", 0.1, 128), zero_sized: ("TENSOR_QUANT8_ASYMM", 0.1, 128), - o3: ("TENSOR_QUANT8_ASYMM", 1.0 / 256, 128) + o3: ("TENSOR_QUANT8_ASYMM", 1.0 / 256, 0) }) # Create test case with dummy values. |