summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorandroid-build-team Robot <android-build-team-robot@google.com>2019-05-16 03:14:26 +0000
committerandroid-build-team Robot <android-build-team-robot@google.com>2019-05-16 03:14:26 +0000
commit2af3599f21dda8dcd4aa1743bf95e17f708c2b67 (patch)
treeea9e0e63cadd83f9677c64801622c998f3e2538e
parentc26781743003e5b928cf6b6b57217ec1851bd0e2 (diff)
parent4363637cbd8ca9b782a9ab5e2bfdc95404649910 (diff)
downloadml-2af3599f21dda8dcd4aa1743bf95e17f708c2b67.tar.gz
Snap for 5571215 from 4363637cbd8ca9b782a9ab5e2bfdc95404649910 to qt-release
Change-Id: I33bd24dfa13da23692ccd985e2ee2fd990953c88
-rw-r--r--nn/common/ExecutionBurstController.cpp126
-rw-r--r--nn/common/Utils.cpp34
-rw-r--r--nn/common/ValidateHal.cpp93
-rw-r--r--nn/common/include/ExecutionBurstController.h54
-rw-r--r--nn/common/include/Utils.h2
-rw-r--r--nn/common/include/ValidateHal.h13
-rw-r--r--nn/common/operations/Activation.cpp43
-rw-r--r--nn/common/operations/RoiPooling.cpp3
-rw-r--r--nn/runtime/ExecutionBuilder.cpp71
-rw-r--r--nn/runtime/test/TestCompliance.cpp88
-rw-r--r--nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp4
-rw-r--r--nn/runtime/test/generated/models/logistic_v1_2.model.cpp4
-rw-r--r--nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp4
-rw-r--r--nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py2
14 files changed, 407 insertions, 134 deletions
diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp
index be8be7105..55ef9ad29 100644
--- a/nn/common/ExecutionBurstController.cpp
+++ b/nn/common/ExecutionBurstController.cpp
@@ -34,6 +34,23 @@ using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
std::numeric_limits<uint64_t>::max()};
+class BurstContextDeathHandler : public hardware::hidl_death_recipient {
+ public:
+ using Callback = std::function<void()>;
+
+ BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
+ CHECK(onDeathCallback != nullptr);
+ }
+
+ void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
+ LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
+ mOnDeathCallback();
+ }
+
+ private:
+ const Callback mOnDeathCallback;
+};
+
} // anonymous namespace
// serialize a request into a packet
@@ -229,13 +246,13 @@ ResultChannelReceiver::getBlocking() {
}
void ResultChannelReceiver::invalidate() {
- mTeardown = true;
+ mValid = false;
// force unblock
- // ExecutionBurstServer is by default waiting on a request packet. If the
- // client process destroys its burst object, the server will still be
- // waiting on the futex (assuming mBlocking is true). This force unblock
- // wakes up any thread waiting on the futex.
+ // ExecutionBurstController waits on a result packet after sending a
+ // request. If the driver containing ExecutionBurstServer crashes, the
+ // controller will still be waiting on the futex (assuming mBlocking is
+ // true). This force unblock wakes up any thread waiting on the futex.
if (mBlocking) {
// TODO: look for a different/better way to signal/notify the futex to
// wake up any thread waiting on it
@@ -249,7 +266,7 @@ void ResultChannelReceiver::invalidate() {
std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
using discriminator = FmqResultDatum::hidl_discriminator;
- if (mTeardown) {
+ if (!mValid) {
return std::nullopt;
}
@@ -259,7 +276,7 @@ std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlock
if (mBlocking) {
success = mFmqResultChannel->readBlocking(&datum, 1);
} else {
- while ((success = !mTeardown.load(std::memory_order_relaxed)) &&
+ while ((success = mValid.load(std::memory_order_relaxed)) &&
!mFmqResultChannel->read(&datum, 1)) {
}
}
@@ -276,8 +293,7 @@ std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlock
std::memcpy(&packet.front(), &datum, sizeof(datum));
success &= mFmqResultChannel->read(packet.data() + 1, count);
- // terminate loop
- if (mTeardown) {
+ if (!mValid) {
return std::nullopt;
}
@@ -315,6 +331,10 @@ bool RequestChannelSender::send(const Request& request, MeasureTiming measure,
}
bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
+ if (!mValid) {
+ return false;
+ }
+
if (packet.size() > mFmqRequestChannel->availableToWrite()) {
LOG(ERROR)
<< "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
@@ -328,6 +348,10 @@ bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet
}
}
+void RequestChannelSender::invalidate() {
+ mValid = false;
+}
+
Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
const hidl_vec<int32_t>& slots, getMemories_cb cb) {
std::lock_guard<std::mutex> guard(mMutex);
@@ -420,19 +444,20 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
// create callback object
sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
- if (callback == nullptr) {
- LOG(ERROR) << "ExecutionBurstController::create failed to create callback";
- return nullptr;
- }
// create FMQ objects
- auto [fmqRequestChannel, fmqRequestDescriptor] =
+ auto [requestChannelSenderTemp, requestChannelDescriptor] =
RequestChannelSender::create(kExecutionBurstChannelLength, blocking);
- auto [fmqResultChannel, fmqResultDescriptor] =
+ auto [resultChannelReceiverTemp, resultChannelDescriptor] =
ResultChannelReceiver::create(kExecutionBurstChannelLength, blocking);
+ std::shared_ptr<RequestChannelSender> requestChannelSender =
+ std::move(requestChannelSenderTemp);
+ std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
+ std::move(resultChannelReceiverTemp);
// check FMQ objects
- if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestDescriptor || !fmqResultDescriptor) {
+ if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
+ !resultChannelDescriptor) {
LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
return nullptr;
}
@@ -441,7 +466,7 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
ErrorStatus errorStatus;
sp<IBurstContext> burstContext;
const Return<void> ret = preparedModel->configureExecutionBurst(
- callback, *fmqRequestDescriptor, *fmqResultDescriptor,
+ callback, *requestChannelDescriptor, *resultChannelDescriptor,
[&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
errorStatus = status;
burstContext = context;
@@ -463,22 +488,61 @@ std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
return nullptr;
}
+ // create death handler object
+ BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
+ resultChannelReceiver] {
+ requestChannelSender->invalidate();
+ resultChannelReceiver->invalidate();
+ };
+ const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
+
+ // linkToDeath registers a callback that will be invoked on service death to
+ // proactively handle service crashes. If the linkToDeath call fails,
+ // asynchronous calls are susceptible to hangs if the service crashes before
+ // providing the response.
+ const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
+ if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
+ LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
+ "for the IBurstContext object.";
+ return nullptr;
+ }
+
// make and return controller
- return std::make_unique<ExecutionBurstController>(
- std::move(fmqRequestChannel), std::move(fmqResultChannel), burstContext, callback);
+ return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
+ burstContext, callback, deathHandler);
}
ExecutionBurstController::ExecutionBurstController(
- std::unique_ptr<RequestChannelSender> requestChannelSender,
- std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
- const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback)
- : mRequestChannelSender(std::move(requestChannelSender)),
- mResultChannelReceiver(std::move(resultChannelReceiver)),
+ const std::shared_ptr<RequestChannelSender>& requestChannelSender,
+ const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
+ const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
+ const sp<hardware::hidl_death_recipient>& deathHandler)
+ : mRequestChannelSender(requestChannelSender),
+ mResultChannelReceiver(resultChannelReceiver),
mBurstContext(burstContext),
- mMemoryCache(callback) {}
+ mMemoryCache(callback),
+ mDeathHandler(deathHandler) {}
+
+ExecutionBurstController::~ExecutionBurstController() {
+ // It is safe to ignore any errors resulting from this unlinkToDeath call
+ // because the ExecutionBurstController object is already being destroyed
+ // and its underlying IBurstContext object is no longer being used by the NN
+ // runtime.
+ if (mDeathHandler) {
+ mBurstContext->unlinkToDeath(mDeathHandler).isOk();
+ }
+}
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute(
const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
+ auto [status, outputShapes, timing, fallback] = tryCompute(request, measure, memoryIds);
+ (void)fallback; // ignore fallback field
+ return {status, std::move(outputShapes), timing};
+}
+
+std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool>
+ExecutionBurstController::tryCompute(const Request& request, MeasureTiming measure,
+ const std::vector<intptr_t>& memoryIds) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
std::lock_guard<std::mutex> guard(mMutex);
@@ -488,16 +552,22 @@ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstControll
const bool success = mRequestChannelSender->send(request, measure, slots);
if (!success) {
LOG(ERROR) << "Error sending FMQ packet";
- return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
+ // only use fallback execution path if the packet could not be sent
+ return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/true};
}
// get result packet
const auto result = mResultChannelReceiver->getBlocking();
if (!result) {
LOG(ERROR) << "Error retrieving FMQ packet";
- return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming};
+        // the request packet was already sent successfully, so a failure
+        // here must not be retried via the fallback execution path
+ return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/false};
}
- return *result;
+
+    // unpack results and return (the request packet was sent successfully,
+    // so the fallback execution path is not needed)
+    auto [status, outputShapes, timing] = std::move(*result);
+    return {status, std::move(outputShapes), timing, /*fallback=*/false};
}
void ExecutionBurstController::freeMemory(intptr_t key) {
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index a1b3b6a7e..731127a26 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -2076,6 +2076,12 @@ static hidl_vec<V1_1::Operation> convertToV1_1(const hidl_vec<V1_0::Operation>&
return result;
}
+bool compliantWithV1_0(const V1_2::Operand& operand) {
+ return validOperandType(static_cast<V1_0::OperandType>(operand.type)) &&
+ (nonExtensionOperandTypeIsScalar(static_cast<int>(operand.type)) ||
+ operand.dimensions.size() != 0);
+}
+
V1_0::Model convertToV1_0(const V1_0::Model& model) {
return model;
}
@@ -2119,7 +2125,33 @@ void logModelToInfo(const V1_2::Model& model) {
static bool compliantWith(HalVersion version, const V1_2::Model& model,
std::set<uint32_t>* noncompliantOperations) {
- auto localValidateOperation = [&model, version](const V1_2::Operation& op) {
+ if (version >= HalVersion::V1_2) return true;
+
+ // A boolean vector indicating whether each pool is compliant with the target HAL version.
+ std::vector<bool> isPoolCompliant(model.pools.size(), false);
+ std::transform(model.pools.begin(), model.pools.end(), isPoolCompliant.begin(),
+ [version](const hidl_memory& pool) { return validatePool(pool, version); });
+
+ // A boolean vector indicating whether each operand is compliant with the target HAL version.
+ std::vector<bool> isOperandCompliant(model.operands.size(), false);
+ std::transform(model.operands.begin(), model.operands.end(), isOperandCompliant.begin(),
+ [&isPoolCompliant](const V1_2::Operand& op) {
+ // There is no V1_1::Operand -- both V1_0::Model and V1_1::Model use
+ // V1_0::Operand.
+ return compliantWithV1_0(op) &&
+ !(op.lifetime == OperandLifeTime::CONSTANT_REFERENCE &&
+ !isPoolCompliant[op.location.poolIndex]);
+ });
+
+ auto allOperandsCompliant = [&isOperandCompliant](const hidl_vec<uint32_t>& indices) {
+ return std::all_of(
+ indices.begin(), indices.end(),
+ [&isOperandCompliant](const uint32_t ind) { return isOperandCompliant[ind]; });
+ };
+
+ auto localValidateOperation = [&model, version,
+ &allOperandsCompliant](const V1_2::Operation& op) {
+ if (!allOperandsCompliant(op.inputs) || !allOperandsCompliant(op.outputs)) return false;
int error = validateOperation(
static_cast<int32_t>(op.type), op.inputs.size(),
op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
diff --git a/nn/common/ValidateHal.cpp b/nn/common/ValidateHal.cpp
index 1015d922d..421730a1d 100644
--- a/nn/common/ValidateHal.cpp
+++ b/nn/common/ValidateHal.cpp
@@ -27,6 +27,21 @@
namespace android {
namespace nn {
+template <class T_Model>
+struct ModelToHalVersion;
+template <>
+struct ModelToHalVersion<V1_0::Model> {
+ static constexpr HalVersion version = HalVersion::V1_0;
+};
+template <>
+struct ModelToHalVersion<V1_1::Model> {
+ static constexpr HalVersion version = HalVersion::V1_1;
+};
+template <>
+struct ModelToHalVersion<V1_2::Model> {
+ static constexpr HalVersion version = HalVersion::V1_2;
+};
+
class MemoryAccessVerifier {
public:
MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
@@ -418,22 +433,26 @@ static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
return true;
}
-static bool validatePools(const hidl_vec<hidl_memory>& pools) {
- for (const hidl_memory& memory : pools) {
- const auto& name = memory.name();
- if (name != "ashmem" && name != "mmap_fd" && name != "hardware_buffer_blob" &&
- name != "hardware_buffer") {
- LOG(ERROR) << "Unsupported memory type " << name;
- return false;
- }
- if (memory.handle() == nullptr) {
- LOG(ERROR) << "Memory of type " << name << " is null";
- return false;
- }
+bool validatePool(const hidl_memory& pool, HalVersion ver) {
+ const auto& name = pool.name();
+ if (name != "ashmem" && name != "mmap_fd" &&
+ ((ver < HalVersion::V1_2) ||
+ (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
+ LOG(ERROR) << "Unsupported memory type " << name;
+ return false;
+ }
+ if (pool.handle() == nullptr) {
+ LOG(ERROR) << "Memory of type " << name << " is null";
+ return false;
}
return true;
}
+static bool validatePools(const hidl_vec<hidl_memory>& pools, HalVersion ver) {
+ return std::all_of(pools.begin(), pools.end(),
+ [ver](const hidl_memory& pool) { return validatePool(pool, ver); });
+}
+
static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes,
const hidl_vec<Operand>& operands, OperandLifeTime lifetime) {
const size_t operandCount = operands.size();
@@ -460,10 +479,10 @@ static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes,
return true;
}
-template <typename VersionedModel>
-static bool validateModelVersioned(const VersionedModel& model, bool allowUnspecifiedRank) {
- NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED,
- "validateModelVersioned");
+template <class T_Model>
+bool validateModel(const T_Model& model) {
+ NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
+ HalVersion version = ModelToHalVersion<T_Model>::version;
if (model.operations.size() == 0 || model.operands.size() == 0) {
LOG(ERROR) << "Invalid empty model.";
return false;
@@ -472,26 +491,18 @@ static bool validateModelVersioned(const VersionedModel& model, bool allowUnspec
// validations we can use operands upcasted to the latest version.
const hidl_vec<Operand> latestVersionOperands = convertToV1_2(model.operands);
return (validateOperands(model.operands, model.operandValues, model.pools,
- allowUnspecifiedRank) &&
+ /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
validateOperations(model.operations, latestVersionOperands) &&
validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
OperandLifeTime::MODEL_INPUT) &&
validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
OperandLifeTime::MODEL_OUTPUT) &&
- validatePools(model.pools));
-}
-
-bool validateModel(const V1_0::Model& model) {
- return validateModelVersioned(model, /*allowUnspecifiedRank=*/false);
+ validatePools(model.pools, version));
}
-bool validateModel(const V1_1::Model& model) {
- return validateModelVersioned(model, /*allowUnspecifiedRank=*/false);
-}
-
-bool validateModel(const V1_2::Model& model) {
- return validateModelVersioned(model, /*allowUnspecifiedRank=*/true);
-}
+template bool validateModel<V1_0::Model>(const V1_0::Model& model);
+template bool validateModel<V1_1::Model>(const V1_1::Model& model);
+template bool validateModel<V1_2::Model>(const V1_2::Model& model);
// Validates the arguments of a request. type is either "input" or "output" and is used
// for printing error messages. The operandIndexes is the appropriate array of input
@@ -572,29 +583,21 @@ static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArg
return true;
}
-template <typename VersionedModel>
-static bool validateRequestVersioned(const Request& request, const VersionedModel& model,
- bool allowDynamicOutputShape) {
+template <class T_Model>
+bool validateRequest(const Request& request, const T_Model& model) {
+ HalVersion version = ModelToHalVersion<T_Model>::version;
return (validateRequestArguments(request.inputs, model.inputIndexes,
convertToV1_2(model.operands), request.pools,
/*allowUnspecified=*/false, "input") &&
validateRequestArguments(request.outputs, model.outputIndexes,
convertToV1_2(model.operands), request.pools,
- /*allowUnspecified=*/allowDynamicOutputShape, "output") &&
- validatePools(request.pools));
+ /*allowUnspecified=*/version >= HalVersion::V1_2, "output") &&
+ validatePools(request.pools, version));
}
-bool validateRequest(const Request& request, const V1_0::Model& model) {
- return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/false);
-}
-
-bool validateRequest(const Request& request, const V1_1::Model& model) {
- return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/false);
-}
-
-bool validateRequest(const Request& request, const V1_2::Model& model) {
- return validateRequestVersioned(request, model, /*allowDynamicOutputShape=*/true);
-}
+template bool validateRequest<V1_0::Model>(const Request& request, const V1_0::Model& model);
+template bool validateRequest<V1_1::Model>(const Request& request, const V1_1::Model& model);
+template bool validateRequest<V1_2::Model>(const Request& request, const V1_2::Model& model);
bool validateExecutionPreference(ExecutionPreference preference) {
return preference == ExecutionPreference::LOW_POWER ||
diff --git a/nn/common/include/ExecutionBurstController.h b/nn/common/include/ExecutionBurstController.h
index 3397d9673..030ea4d2d 100644
--- a/nn/common/include/ExecutionBurstController.h
+++ b/nn/common/include/ExecutionBurstController.h
@@ -117,7 +117,7 @@ class ResultChannelReceiver {
private:
const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
- std::atomic<bool> mTeardown{false};
+ std::atomic<bool> mValid{true};
const bool mBlocking;
};
@@ -157,6 +157,13 @@ class RequestChannelSender {
*/
bool send(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots);
+ /**
+ * Method to mark the channel as invalid, causing all future calls to
+ * RequestChannelSender::send to immediately return false without attempting
+ * to send a message across the FMQ.
+ */
+ void invalidate();
+
// prefer calling RequestChannelSender::send
bool sendPacket(const std::vector<FmqRequestDatum>& packet);
@@ -164,6 +171,7 @@ class RequestChannelSender {
private:
const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
+ std::atomic<bool> mValid{true};
const bool mBlocking;
};
@@ -259,10 +267,15 @@ class ExecutionBurstController {
static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel,
bool blocking);
- ExecutionBurstController(std::unique_ptr<RequestChannelSender> requestChannelSender,
- std::unique_ptr<ResultChannelReceiver> resultChannelReceiver,
+ // prefer calling ExecutionBurstController::create
+ ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender,
+ const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
const sp<IBurstContext>& burstContext,
- const sp<ExecutionBurstCallback>& callback);
+ const sp<ExecutionBurstCallback>& callback,
+ const sp<hardware::hidl_death_recipient>& deathHandler = nullptr);
+
+ // explicit destructor to unregister the death recipient
+ ~ExecutionBurstController();
/**
* Execute a request on a model.
@@ -271,12 +284,36 @@ class ExecutionBurstController {
* @param measure Whether to collect timing measurements, either YES or NO
* @param memoryIds Identifiers corresponding to each memory object in the
* request's pools.
- * @return status and output shape of the execution and any execution time
- * measurements.
+ * @return A tuple of:
+ * - status of the execution
+ * - dynamic output shapes from the execution
+ * - any execution time measurements of the execution
*/
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute(
const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);
+ // TODO: combine "compute" and "tryCompute" back into a single function.
+ // "tryCompute" was created later to return the "fallback" boolean. This
+ // could not be done directly in "compute" because the VTS test cases (which
+ // test burst using "compute") had already been locked down and could not be
+ // changed.
+ /**
+ * Execute a request on a model.
+ *
+ * @param request Arguments to be executed on a model.
+ * @param measure Whether to collect timing measurements, either YES or NO
+ * @param memoryIds Identifiers corresponding to each memory object in the
+ * request's pools.
+ * @return A tuple of:
+ * - status of the execution
+ * - dynamic output shapes from the execution
+ * - any execution time measurements of the execution
+ * - whether or not a failed burst execution should be re-run using a
+ * different path (e.g., IPreparedModel::executeSynchronously)
+ */
+ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> tryCompute(
+ const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);
+
/**
* Propagate a user's freeing of memory to the service.
*
@@ -286,10 +323,11 @@ class ExecutionBurstController {
private:
std::mutex mMutex;
- const std::unique_ptr<RequestChannelSender> mRequestChannelSender;
- const std::unique_ptr<ResultChannelReceiver> mResultChannelReceiver;
+ const std::shared_ptr<RequestChannelSender> mRequestChannelSender;
+ const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver;
const sp<IBurstContext> mBurstContext;
const sp<ExecutionBurstCallback> mMemoryCache;
+ const sp<hardware::hidl_death_recipient> mDeathHandler;
};
} // namespace android::nn
diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h
index 64472f12a..bf6cffda4 100644
--- a/nn/common/include/Utils.h
+++ b/nn/common/include/Utils.h
@@ -315,6 +315,8 @@ bool compliantWithV1_2(const V1_0::Capabilities& capabilities);
bool compliantWithV1_2(const V1_1::Capabilities& capabilities);
bool compliantWithV1_2(const V1_2::Capabilities& capabilities);
+bool compliantWithV1_0(const V1_2::Operand& operand);
+
// If noncompliantOperations != nullptr, then
// precondition: noncompliantOperations->empty()
// postcondition: *noncompliantOperations consists of the indices of the noncompliant
diff --git a/nn/common/include/ValidateHal.h b/nn/common/include/ValidateHal.h
index c953d8a10..4275a24a7 100644
--- a/nn/common/include/ValidateHal.h
+++ b/nn/common/include/ValidateHal.h
@@ -36,17 +36,15 @@ enum class HalVersion : int32_t {
// IMPORTANT: This function cannot validate that OEM operation and operands
// are correctly defined, as these are specific to each implementation.
// Each driver should do their own validation of OEM types.
-bool validateModel(const V1_0::Model& model);
-bool validateModel(const V1_1::Model& model);
-bool validateModel(const V1_2::Model& model);
+template <class T_Model>
+bool validateModel(const T_Model& model);
// Verfies that the request for the given model is valid.
// IMPORTANT: This function cannot validate that OEM operation and operands
// are correctly defined, as these are specific to each implementation.
// Each driver should do their own validation of OEM types.
-bool validateRequest(const Request& request, const V1_0::Model& model);
-bool validateRequest(const Request& request, const V1_1::Model& model);
-bool validateRequest(const Request& request, const V1_2::Model& model);
+template <class T_Model>
+bool validateRequest(const Request& request, const T_Model& model);
// Verfies that the execution preference is valid.
bool validateExecutionPreference(ExecutionPreference preference);
@@ -58,6 +56,9 @@ bool validOperationType(V1_2::OperationType operation);
bool validOperandType(V1_0::OperandType operand);
bool validOperandType(V1_2::OperandType operand);
+// Verifies that the memory pool is valid in the specified HAL version.
+bool validatePool(const hidl_memory& pool, HalVersion ver = HalVersion::LATEST);
+
} // namespace nn
} // namespace android
diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp
index 491226e60..f85f6b4bf 100644
--- a/nn/common/operations/Activation.cpp
+++ b/nn/common/operations/Activation.cpp
@@ -226,11 +226,28 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
}
-bool prepare(IOperationExecutionContext* context) {
+bool prepare(OperationType opType, IOperationExecutionContext* context) {
Shape input = context->getInputShape(kInputTensor);
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
- Shape output = context->getOutputShape(kOutputTensor);
- output.dimensions = input.dimensions;
+ Shape output = input;
+ if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
+ switch (opType) {
+ case OperationType::RELU:
+ case OperationType::RELU1:
+ case OperationType::RELU6:
+ break;
+ case OperationType::LOGISTIC:
+ output.scale = 1.f / 256;
+ output.offset = 0;
+ break;
+ case OperationType::TANH:
+ output.scale = 1.f / 128;
+ output.offset = 128;
+ break;
+ default:
+ NN_RET_CHECK_FAIL() << "Unsupported operation type";
+ }
+ }
return context->setOutputShape(kOutputTensor, output);
}
@@ -326,7 +343,7 @@ bool executeLogistic(IOperationExecutionContext* context) {
context->getOutputBuffer<uint8_t>(kOutputTensor),
context->getOutputShape(kOutputTensor));
default:
- NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
+ NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
}
}
@@ -350,7 +367,7 @@ bool executeTanh(IOperationExecutionContext* context) {
context->getOutputBuffer<uint8_t>(kOutputTensor),
context->getOutputShape(kOutputTensor));
default:
- NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
+ NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
}
}
@@ -358,17 +375,21 @@ bool executeTanh(IOperationExecutionContext* context) {
using std::placeholders::_1;
NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1),
- activation::prepare, activation::executeRelu, .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::RELU, _1),
+ activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1),
- activation::prepare, activation::executeRelu1, .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::RELU1, _1),
+ activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1),
- activation::prepare, activation::executeRelu6, .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::RELU6, _1),
+ activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC",
std::bind(activation::validate, OperationType::LOGISTIC, _1),
- activation::prepare, activation::executeLogistic,
- .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::LOGISTIC, _1),
+ activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1),
- activation::prepare, activation::executeTanh, .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::TANH, _1),
+ activation::executeTanh, .allowZeroSizedInput = true);
} // namespace nn
} // namespace android
diff --git a/nn/common/operations/RoiPooling.cpp b/nn/common/operations/RoiPooling.cpp
index 7edbd8123..37914fdc9 100644
--- a/nn/common/operations/RoiPooling.cpp
+++ b/nn/common/operations/RoiPooling.cpp
@@ -235,8 +235,7 @@ bool prepare(IOperationExecutionContext* context) {
NN_RET_CHECK_EQ(roiShape.offset, 0);
}
- Shape output = context->getOutputShape(kOutputTensor);
- output.type = input.type;
+ Shape output = input;
if (useNchw) {
output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
static_cast<uint32_t>(outputWidth)};
diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp
index 8866a5a05..35b6788b4 100644
--- a/nn/runtime/ExecutionBuilder.cpp
+++ b/nn/runtime/ExecutionBuilder.cpp
@@ -898,7 +898,10 @@ int StepExecutor::startComputeOnDevice(
// in the design document.
sp<ExecutionCallback> executionCallback = new ExecutionCallback();
- if (burstController != nullptr) {
+ // compute using burst if present
+ const bool burstCompute = (burstController != nullptr);
+ bool burstFallback = false;
+ if (burstCompute) {
std::vector<intptr_t> memoryIds;
memoryIds.reserve(mMemories.size());
for (const Memory* memory : mMemories) {
@@ -906,34 +909,46 @@ int StepExecutor::startComputeOnDevice(
memoryIds.push_back(memory->getKey());
}
- VLOG(EXECUTION) << "Before ExecutionBurstController->compute() "
+ VLOG(EXECUTION) << "Before ExecutionBurstController->tryCompute() "
<< SHOW_IF_DEBUG(toString(request));
- auto burstExecuteResult =
- burstController->compute(request, measureTiming(mExecutionBuilder), memoryIds);
- executionCallback->notify(std::get<0>(burstExecuteResult), std::get<1>(burstExecuteResult),
- std::get<2>(burstExecuteResult));
- } else if (DeviceManager::get()->syncExecHal()) {
- VLOG(EXECUTION) << "Before mPreparedModel->executeSynchronously() "
- << SHOW_IF_DEBUG(toString(request));
- auto syncExecuteResult =
- mPreparedModel->executeSynchronously(request, measureTiming(mExecutionBuilder));
- executionCallback->notify(std::get<0>(syncExecuteResult), std::get<1>(syncExecuteResult),
- std::get<2>(syncExecuteResult));
- } else {
- VLOG(EXECUTION) << "Before mPreparedModel->execute() " << SHOW_IF_DEBUG(toString(request));
- // Execute.
- // TODO: What happens to the Callback if the service dies abnormally
- // -- won't that keep the Callback live forever, because the service
- // never has the opportunity to bump the reference count down? Or
- // maybe the HIDL infrastructure handles this magically? At worst,
- // it seems like this is a small memory leak, if the Callback stays
- // alive forever.
- Return<ErrorStatus> executeStatus = mPreparedModel->execute(
- request, measureTiming(mExecutionBuilder), executionCallback);
- if (!executeStatus.isOk() || executeStatus != ErrorStatus::NONE) {
- VLOG(EXECUTION) << "**Execute launch failed**";
- return executeStatus.isOk() ? convertErrorStatusToResultCode(executeStatus)
- : ANEURALNETWORKS_OP_FAILED;
+ auto [status, outputShapes, timing, fallback] =
+ burstController->tryCompute(request, measureTiming(mExecutionBuilder), memoryIds);
+
+ burstFallback = fallback;
+ if (!fallback) {
+ executionCallback->notify(status, outputShapes, timing);
+ }
+ }
+
+ // compute from IPreparedModel if either:
+ // (1) burst was not supplied, or
+ // (2) the burst execution failed and requested a fallback execution
+ if (!burstCompute || burstFallback) {
+ if (DeviceManager::get()->syncExecHal()) {
+ VLOG(EXECUTION) << "Before mPreparedModel->executeSynchronously() "
+ << SHOW_IF_DEBUG(toString(request));
+ auto syncExecuteResult =
+ mPreparedModel->executeSynchronously(request, measureTiming(mExecutionBuilder));
+ executionCallback->notify(std::get<0>(syncExecuteResult),
+ std::get<1>(syncExecuteResult),
+ std::get<2>(syncExecuteResult));
+ } else {
+ VLOG(EXECUTION) << "Before mPreparedModel->execute() "
+ << SHOW_IF_DEBUG(toString(request));
+ // Execute.
+ // TODO: What happens to the Callback if the service dies abnormally
+ // -- won't that keep the Callback live forever, because the service
+ // never has the opportunity to bump the reference count down? Or
+ // maybe the HIDL infrastructure handles this magically? At worst,
+ // it seems like this is a small memory leak, if the Callback stays
+ // alive forever.
+ Return<ErrorStatus> executeStatus = mPreparedModel->execute(
+ request, measureTiming(mExecutionBuilder), executionCallback);
+ if (!executeStatus.isOk() || executeStatus != ErrorStatus::NONE) {
+ VLOG(EXECUTION) << "**Execute launch failed**";
+ return executeStatus.isOk() ? convertErrorStatusToResultCode(executeStatus)
+ : ANEURALNETWORKS_OP_FAILED;
+ }
}
}
diff --git a/nn/runtime/test/TestCompliance.cpp b/nn/runtime/test/TestCompliance.cpp
index 93918c803..52764154c 100644
--- a/nn/runtime/test/TestCompliance.cpp
+++ b/nn/runtime/test/TestCompliance.cpp
@@ -27,12 +27,15 @@ namespace compliance_test {
using namespace ::android::nn;
using HidlModel = V1_2::Model;
using WrapperModel = test_wrapper::Model;
+using WrapperOperandType = test_wrapper::OperandType;
+using WrapperType = test_wrapper::Type;
// Creates a HIDL model from a creator of the wrapper model.
static HidlModel createHidlModel(std::function<void(WrapperModel*)> createModel) {
HidlModel hidlModel;
WrapperModel wrapperModel;
createModel(&wrapperModel);
+ EXPECT_EQ(wrapperModel.finish(), test_wrapper::Result::NO_ERROR);
ModelBuilder* modelBuilder = reinterpret_cast<ModelBuilder*>(wrapperModel.getHandle());
modelBuilder->setHidlModel(&hidlModel);
return hidlModel;
@@ -56,4 +59,89 @@ void ComplianceTest::testAvailableSinceV1_0(std::function<void(WrapperModel*)> c
ASSERT_TRUE(compliantWithV1_0(model));
}
+static const WrapperOperandType kTypeTensorFloat(WrapperType::TENSOR_FLOAT32, {1});
+static const WrapperOperandType kTypeTensorFloatRank0(WrapperType::TENSOR_FLOAT32, {});
+static const WrapperOperandType kTypeInt32(WrapperType::INT32, {});
+
+TEST_F(ComplianceTest, Rank0TensorModelInput) {
+ int32_t act_init = 0;
+ // A simple ADD operation: op1 ADD op2 = op3, with op1 and op2 of rank 0.
+ testAvailableSinceV1_2([&act_init](WrapperModel* model) {
+ auto op1 = model->addOperand(&kTypeTensorFloatRank0);
+ auto op2 = model->addOperand(&kTypeTensorFloatRank0);
+ auto act = model->addOperand(&kTypeInt32);
+ auto op3 = model->addOperand(&kTypeTensorFloat);
+ model->setOperandValue(act, &act_init, sizeof(act_init));
+ model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3});
+ model->identifyInputsAndOutputs({op1, op2}, {op3});
+ assert(model->isValid());
+ });
+}
+
+TEST_F(ComplianceTest, Rank0TensorModelOutput) {
+ int32_t act_init = 0;
+ // A simple ADD operation: op1 ADD op2 = op3, with op3 of rank 0.
+ testAvailableSinceV1_2([&act_init](WrapperModel* model) {
+ auto op1 = model->addOperand(&kTypeTensorFloat);
+ auto op2 = model->addOperand(&kTypeTensorFloat);
+ auto act = model->addOperand(&kTypeInt32);
+ auto op3 = model->addOperand(&kTypeTensorFloatRank0);
+ model->setOperandValue(act, &act_init, sizeof(act_init));
+ model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3});
+ model->identifyInputsAndOutputs({op1, op2}, {op3});
+ assert(model->isValid());
+ });
+}
+
+TEST_F(ComplianceTest, Rank0TensorTemporaryVariable) {
+ int32_t act_init = 0;
+ // Two ADD operations: op1 ADD op2 = op3, op3 ADD op4 = op5, with op3 of rank 0.
+ testAvailableSinceV1_2([&act_init](WrapperModel* model) {
+ auto op1 = model->addOperand(&kTypeTensorFloat);
+ auto op2 = model->addOperand(&kTypeTensorFloat);
+ auto op3 = model->addOperand(&kTypeTensorFloatRank0);
+ auto op4 = model->addOperand(&kTypeTensorFloat);
+ auto op5 = model->addOperand(&kTypeTensorFloat);
+ auto act = model->addOperand(&kTypeInt32);
+ model->setOperandValue(act, &act_init, sizeof(act_init));
+ model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3});
+ model->addOperation(ANEURALNETWORKS_ADD, {op3, op4, act}, {op5});
+ model->identifyInputsAndOutputs({op1, op2, op4}, {op5});
+ assert(model->isValid());
+ });
+}
+
+TEST_F(ComplianceTest, HardwareBuffer) {
+ const size_t memorySize = 20;
+ AHardwareBuffer_Desc desc{
+ .width = memorySize,
+ .height = 1,
+ .layers = 1,
+ .format = AHARDWAREBUFFER_FORMAT_BLOB,
+ .usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN | AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN,
+ };
+
+ AHardwareBuffer* buffer = nullptr;
+ ASSERT_EQ(AHardwareBuffer_allocate(&desc, &buffer), 0);
+ test_wrapper::Memory memory(buffer);
+ ASSERT_TRUE(memory.isValid());
+
+ int32_t act_init = 0;
+
+ // A simple ADD operation: op1 ADD op2 = op3, with op2 using a const hardware buffer.
+ testAvailableSinceV1_2([&memory, &act_init](WrapperModel* model) {
+ auto op1 = model->addOperand(&kTypeTensorFloat);
+ auto op2 = model->addOperand(&kTypeTensorFloat);
+ auto act = model->addOperand(&kTypeInt32);
+ auto op3 = model->addOperand(&kTypeTensorFloat);
+ model->setOperandValueFromMemory(op2, &memory, 0, sizeof(float));
+ model->setOperandValue(act, &act_init, sizeof(act_init));
+ model->addOperation(ANEURALNETWORKS_ADD, {op1, op2, act}, {op3});
+ model->identifyInputsAndOutputs({op1}, {op3});
+ assert(model->isValid());
+ });
+
+ AHardwareBuffer_release(buffer);
+}
+
} // namespace compliance_test
diff --git a/nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp b/nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp
index afc19c3ee..101e2f707 100644
--- a/nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp
+++ b/nn/runtime/test/fuzzing/operation_signatures/BoundingBox.cpp
@@ -57,6 +57,10 @@ static void roiConstructor(Type, uint32_t rank, RandomOperation* op) {
} else {
op->outputs[0]->dimensions = {outBatch, outHeight, outWidth, outDepth};
}
+
+ if (op->opType == ANEURALNETWORKS_ROI_POOLING) {
+ setSameQuantization(op->outputs[0], op->inputs[0]);
+ }
}
template <typename T>
diff --git a/nn/runtime/test/generated/models/logistic_v1_2.model.cpp b/nn/runtime/test/generated/models/logistic_v1_2.model.cpp
index 3c5187861..170d881eb 100644
--- a/nn/runtime/test/generated/models/logistic_v1_2.model.cpp
+++ b/nn/runtime/test/generated/models/logistic_v1_2.model.cpp
@@ -252,7 +252,7 @@ void CreateModel_zero_sized_quant8(Model *model) {
OperandType type10(Type::BOOL, {});
OperandType type14(Type::TENSOR_QUANT8_ASYMM, {0, 2, 2, 1}, 0.1f, 128);
OperandType type15(Type::TENSOR_QUANT8_ASYMM, {1, 1, 1, 1}, 0.1f, 128);
- OperandType type16(Type::TENSOR_QUANT8_ASYMM, {0, 2, 2, 1}, 0.00390625f, 128);
+ OperandType type16(Type::TENSOR_QUANT8_ASYMM, {0, 2, 2, 1}, 0.00390625f, 0);
OperandType type17(Type::TENSOR_QUANT16_ASYMM, {1, 8}, 0.125f, 0);
OperandType type18(Type::TENSOR_QUANT16_ASYMM, {0, 4}, 0.125f, 0);
OperandType type19(Type::TENSOR_QUANT8_ASYMM, {1, 2}, 0.1f, 128);
@@ -597,7 +597,7 @@ void CreateModel_zero_sized_dynamic_output_shape_quant8(Model *model) {
OperandType type18(Type::TENSOR_QUANT16_ASYMM, {0, 4}, 0.125f, 0);
OperandType type19(Type::TENSOR_QUANT8_ASYMM, {1, 2}, 0.1f, 128);
OperandType type20(Type::TENSOR_QUANT8_ASYMM, {0}, 0.1f, 128);
- OperandType type29(Type::TENSOR_QUANT8_ASYMM, {0, 0, 0, 0}, 0.00390625f, 128);
+ OperandType type29(Type::TENSOR_QUANT8_ASYMM, {0, 0, 0, 0}, 0.00390625f, 0);
OperandType type5(Type::TENSOR_INT32, {0});
OperandType type7(Type::TENSOR_INT32, {1});
OperandType type8(Type::FLOAT32, {});
diff --git a/nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp b/nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp
index cd0f47a49..fc9419281 100644
--- a/nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp
+++ b/nn/runtime/test/generated/vts_models/logistic_v1_2.model.cpp
@@ -915,7 +915,7 @@ Model createTestModel_zero_sized_quant8() {
.dimensions = {0, 2, 2, 1},
.numberOfConsumers = 0,
.scale = 0.00390625f,
- .zeroPoint = 128,
+ .zeroPoint = 0,
.lifetime = OperandLifeTime::MODEL_OUTPUT,
.location = {.poolIndex = 0, .offset = 0, .length = 0},
}
@@ -1924,7 +1924,7 @@ Model createTestModel_zero_sized_dynamic_output_shape_quant8() {
.dimensions = {0, 0, 0, 0},
.numberOfConsumers = 0,
.scale = 0.00390625f,
- .zeroPoint = 128,
+ .zeroPoint = 0,
.lifetime = OperandLifeTime::MODEL_OUTPUT,
.location = {.poolIndex = 0, .offset = 0, .length = 0},
}
diff --git a/nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py b/nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py
index f8037b1c5..fe91a814d 100644
--- a/nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py
+++ b/nn/runtime/test/specs/V1_2/logistic_v1_2.mod.py
@@ -82,7 +82,7 @@ quant8 = DataTypeConverter().Identify({
tmp1: ("TENSOR_QUANT16_ASYMM", 0.125, 0),
i1: ("TENSOR_QUANT8_ASYMM", 0.1, 128),
zero_sized: ("TENSOR_QUANT8_ASYMM", 0.1, 128),
- o3: ("TENSOR_QUANT8_ASYMM", 1.0 / 256, 128)
+ o3: ("TENSOR_QUANT8_ASYMM", 1.0 / 256, 0)
})
# Create test case with dummy values.