summaryrefslogtreecommitdiff
path: root/nn
diff options
context:
space:
mode:
authorXusong Wang <xusongw@google.com>2020-05-12 16:43:34 +0000
committerAndroid (Google) Code Review <android-gerrit@google.com>2020-05-12 16:43:34 +0000
commit9926f82774463e258dcf2189341a2d36dd56f100 (patch)
tree4e7787eab0244d4bb6f71ba0f17568965d00aa51 /nn
parentef9c23ee28a3f36e494b3e9fc2aeba1cc284ae3c (diff)
parentca8c1cba4c7ba6612e0471be620c899e27032f77 (diff)
downloadml-9926f82774463e258dcf2189341a2d36dd56f100.tar.gz
Merge "Avoid sending ahwb requests to 1.0 and 1.1 drivers." into rvc-dev
Diffstat (limited to 'nn')
-rw-r--r--nn/common/Utils.cpp44
-rw-r--r--nn/common/include/Utils.h2
-rw-r--r--nn/runtime/Manager.cpp8
-rw-r--r--nn/runtime/VersionedInterfaces.cpp27
-rw-r--r--nn/runtime/test/TestCompliance.cpp32
5 files changed, 88 insertions, 25 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp
index cd97ffa52..81e5cf1e1 100644
--- a/nn/common/Utils.cpp
+++ b/nn/common/Utils.cpp
@@ -21,6 +21,8 @@
#include <android-base/logging.h>
#include <android-base/properties.h>
#include <android-base/strings.h>
+#include <errno.h>
+#include <poll.h>
#include <sys/system_properties.h>
#include <algorithm>
@@ -32,9 +34,6 @@
#include <utility>
#include <vector>
-#include <errno.h>
-#include <poll.h>
-
#include "ControlFlow.h"
#include "NeuralNetworks.h"
#include "NeuralNetworksOEM.h"
@@ -3100,7 +3099,22 @@ bool compliantWithV1_0(const V1_0::Request& request) {
bool compliantWithV1_0(const V1_3::Request& request) {
return std::all_of(request.pools.begin(), request.pools.end(), [](const auto& pool) {
- return pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory;
+ if (pool.getDiscriminator() != V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) {
+ return false;
+ }
+ const auto& name = pool.hidlMemory().name();
+ return name == "ashmem" || name == "mmap_fd";
+ });
+}
+
+bool compliantWithV1_2(const V1_3::Request& request) {
+ return std::all_of(request.pools.begin(), request.pools.end(), [](const auto& pool) {
+ if (pool.getDiscriminator() != V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) {
+ return false;
+ }
+ const auto& name = pool.hidlMemory().name();
+ return name == "ashmem" || name == "mmap_fd" || name == "hardware_buffer_blob" ||
+ name == "hardware_buffer";
});
}
@@ -3123,17 +3137,29 @@ V1_0::Request convertToV1_0(const V1_0::Request& request) {
return request;
}
-V1_0::Request convertToV1_0(const V1_3::Request& request) {
- if (!compliantWithV1_0(request)) {
- LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
- << " from V1_3::Request to V1_0::Request";
- }
+static V1_0::Request uncheckedConvertToV1_0(const V1_3::Request& request) {
hidl_vec<hidl_memory> pools(request.pools.size());
std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
[](const auto& pool) { return convertToV1_0(pool); });
return {.inputs = request.inputs, .outputs = request.outputs, .pools = std::move(pools)};
}
+V1_0::Request convertToV1_0(const V1_3::Request& request) {
+ if (!compliantWithV1_0(request)) {
+ LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
+ << " from V1_3::Request to V1_0::Request of version 1.0";
+ }
+ return uncheckedConvertToV1_0(request);
+}
+
+V1_0::Request convertToV1_2(const V1_3::Request& request) {
+ if (!compliantWithV1_2(request)) {
+ LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
+ << " from V1_3::Request to V1_0::Request of version 1.2";
+ }
+ return uncheckedConvertToV1_0(request);
+}
+
V1_3::Request convertToV1_3(const V1_0::Request& request) {
hidl_vec<V1_3::Request::MemoryPool> pools(request.pools.size());
std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h
index 24e69211c..ca11c5ebc 100644
--- a/nn/common/include/Utils.h
+++ b/nn/common/include/Utils.h
@@ -530,9 +530,11 @@ hal::hidl_vec<hal::V1_3::Operand> convertToV1_3(const hal::hidl_vec<hal::V1_3::O
bool compliantWithV1_0(const hal::V1_0::Request& request);
bool compliantWithV1_0(const hal::V1_3::Request& request);
+bool compliantWithV1_2(const hal::V1_3::Request& request);
hal::V1_0::Request convertToV1_0(const hal::V1_0::Request& request);
hal::V1_0::Request convertToV1_0(const hal::V1_3::Request& request);
+hal::V1_0::Request convertToV1_2(const hal::V1_3::Request& request);
hal::V1_3::Request convertToV1_3(const hal::V1_0::Request& request);
hal::V1_3::Request convertToV1_3(const hal::V1_3::Request& request);
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index 310710e3c..634cd2aec 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -379,9 +379,9 @@ std::tuple<int, std::vector<OutputShape>, Timing> DriverPreparedModel::execute(
const bool burstCompute = (burstController != nullptr);
bool burstFallback = true;
if (burstCompute) {
- const bool compliant = compliantWithV1_0(request);
+ const bool compliant = compliantWithV1_2(request);
if (compliant) {
- V1_0::Request request10 = convertToV1_0(request);
+ V1_0::Request request12 = convertToV1_2(request);
std::vector<intptr_t> memoryIds;
memoryIds.reserve(localMemories.size());
for (const Memory* memory : localMemories) {
@@ -390,9 +390,9 @@ std::tuple<int, std::vector<OutputShape>, Timing> DriverPreparedModel::execute(
}
VLOG(EXECUTION) << "Before ExecutionBurstController->compute() "
- << SHOW_IF_DEBUG(toString(request10));
+ << SHOW_IF_DEBUG(toString(request12));
std::tie(n, outputShapes, timing, burstFallback) =
- burstController->compute(request10, measure, memoryIds);
+ burstController->compute(request12, measure, memoryIds);
}
}
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index 3ae950eac..33d290cfe 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -241,17 +241,16 @@ std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::execu
return getResults(*callback);
}
- const bool compliant = compliantWithV1_0(request);
- if (!compliant) {
- LOG(ERROR) << "Could not handle execute or execute_1_2!";
- return failWithStatus(ErrorStatus::GENERAL_FAILURE);
- }
- const V1_0::Request request10 = convertToV1_0(request);
-
// version 1.2 HAL
if (mPreparedModelV1_2 != nullptr) {
+ const bool compliant = compliantWithV1_2(request);
+ if (!compliant) {
+ LOG(ERROR) << "Could not handle execute_1_2!";
+ return failWithStatus(ErrorStatus::GENERAL_FAILURE);
+ }
+ const V1_0::Request request12 = convertToV1_2(request);
Return<V1_0::ErrorStatus> ret =
- mPreparedModelV1_2->execute_1_2(request10, measure, callback);
+ mPreparedModelV1_2->execute_1_2(request12, measure, callback);
if (ret.isDeadObject()) {
LOG(ERROR) << "execute_1_2 failure: " << ret.description();
return failDeadObject();
@@ -271,6 +270,12 @@ std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::execu
// version 1.0 HAL
if (mPreparedModelV1_0 != nullptr) {
+ const bool compliant = compliantWithV1_0(request);
+ if (!compliant) {
+ LOG(ERROR) << "Could not handle execute!";
+ return failWithStatus(ErrorStatus::GENERAL_FAILURE);
+ }
+ const V1_0::Request request10 = convertToV1_0(request);
Return<V1_0::ErrorStatus> ret = mPreparedModelV1_0->execute(request10, callback);
if (ret.isDeadObject()) {
LOG(ERROR) << "execute failure: " << ret.description();
@@ -324,16 +329,16 @@ std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::execu
// version 1.2 HAL
if (mPreparedModelV1_2 != nullptr) {
- const bool compliant = compliantWithV1_0(request);
+ const bool compliant = compliantWithV1_2(request);
if (!compliant) {
LOG(ERROR) << "Could not handle executeSynchronously!";
return kFailure;
}
- const V1_0::Request request10 = convertToV1_0(request);
+ const V1_0::Request request12 = convertToV1_2(request);
std::tuple<int, std::vector<OutputShape>, Timing> result;
Return<void> ret = mPreparedModelV1_2->executeSynchronously(
- request10, measure,
+ request12, measure,
[&result](V1_0::ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
const Timing& timing) {
result = getExecutionResult(convertToV1_3(error), outputShapes, timing);
diff --git a/nn/runtime/test/TestCompliance.cpp b/nn/runtime/test/TestCompliance.cpp
index 53bff038b..db5ab4d3e 100644
--- a/nn/runtime/test/TestCompliance.cpp
+++ b/nn/runtime/test/TestCompliance.cpp
@@ -18,6 +18,7 @@
#include "GeneratedTestUtils.h"
#include "HalInterfaces.h"
+#include "Memory.h"
#include "MemoryUtils.h"
#include "ModelBuilder.h"
#include "TestNeuralNetworksWrapper.h"
@@ -71,8 +72,14 @@ static void testAvailableSinceV1_0(const WrapperModel& wrapperModel) {
ASSERT_TRUE(compliantWithV1_0(hidlModel));
}
+static void testAvailableSinceV1_2(const Request& request) {
+ ASSERT_FALSE(compliantWithV1_0(request));
+ ASSERT_TRUE(compliantWithV1_2(request));
+}
+
static void testAvailableSinceV1_3(const Request& request) {
ASSERT_FALSE(compliantWithV1_0(request));
+ ASSERT_FALSE(compliantWithV1_2(request));
}
static const WrapperOperandType kTypeTensorFloat(WrapperType::TENSOR_FLOAT32, {1});
@@ -126,7 +133,7 @@ TEST_F(ComplianceTest, Rank0TensorTemporaryVariable) {
testAvailableSinceV1_2(model);
}
-TEST_F(ComplianceTest, HardwareBuffer) {
+TEST_F(ComplianceTest, HardwareBufferModel) {
const size_t memorySize = 20;
AHardwareBuffer_Desc desc{
.width = memorySize,
@@ -157,6 +164,29 @@ TEST_F(ComplianceTest, HardwareBuffer) {
AHardwareBuffer_release(buffer);
}
+TEST_F(ComplianceTest, HardwareBufferRequest) {
+ const auto [n, ahwb] = MemoryRuntimeAHWB::create(1024);
+ ASSERT_EQ(n, ANEURALNETWORKS_NO_ERROR);
+ Request::MemoryPool sharedMemoryPool, ahwbMemoryPool = ahwb->getMemoryPool();
+ sharedMemoryPool.hidlMemory(allocateSharedMemory(1024));
+ ASSERT_TRUE(sharedMemoryPool.hidlMemory().valid());
+ ASSERT_TRUE(ahwbMemoryPool.hidlMemory().valid());
+
+ // AHardwareBuffer as input.
+ testAvailableSinceV1_2(Request{
+ .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}},
+ .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}},
+ .pools = {ahwbMemoryPool, sharedMemoryPool},
+ });
+
+ // AHardwareBuffer as output.
+ testAvailableSinceV1_2(Request{
+ .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}},
+ .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}},
+ .pools = {sharedMemoryPool, ahwbMemoryPool},
+ });
+}
+
TEST_F(ComplianceTest, DeviceMemory) {
Request::MemoryPool sharedMemoryPool, deviceMemoryPool;
sharedMemoryPool.hidlMemory(allocateSharedMemory(1024));