diff options
Diffstat (limited to 'nn/runtime')
-rw-r--r-- | nn/runtime/Manager.cpp | 8 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.cpp | 27 | ||||
-rw-r--r-- | nn/runtime/test/TestCompliance.cpp | 32 |
3 files changed, 51 insertions, 16 deletions
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp index 310710e3c..634cd2aec 100644 --- a/nn/runtime/Manager.cpp +++ b/nn/runtime/Manager.cpp @@ -379,9 +379,9 @@ std::tuple<int, std::vector<OutputShape>, Timing> DriverPreparedModel::execute( const bool burstCompute = (burstController != nullptr); bool burstFallback = true; if (burstCompute) { - const bool compliant = compliantWithV1_0(request); + const bool compliant = compliantWithV1_2(request); if (compliant) { - V1_0::Request request10 = convertToV1_0(request); + V1_0::Request request12 = convertToV1_2(request); std::vector<intptr_t> memoryIds; memoryIds.reserve(localMemories.size()); for (const Memory* memory : localMemories) { @@ -390,9 +390,9 @@ std::tuple<int, std::vector<OutputShape>, Timing> DriverPreparedModel::execute( } VLOG(EXECUTION) << "Before ExecutionBurstController->compute() " - << SHOW_IF_DEBUG(toString(request10)); + << SHOW_IF_DEBUG(toString(request12)); std::tie(n, outputShapes, timing, burstFallback) = - burstController->compute(request10, measure, memoryIds); + burstController->compute(request12, measure, memoryIds); } } diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index 3ae950eac..33d290cfe 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -241,17 +241,16 @@ std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::execu return getResults(*callback); } - const bool compliant = compliantWithV1_0(request); - if (!compliant) { - LOG(ERROR) << "Could not handle execute or execute_1_2!"; - return failWithStatus(ErrorStatus::GENERAL_FAILURE); - } - const V1_0::Request request10 = convertToV1_0(request); - // version 1.2 HAL if (mPreparedModelV1_2 != nullptr) { + const bool compliant = compliantWithV1_2(request); + if (!compliant) { + LOG(ERROR) << "Could not handle execute_1_2!"; + return failWithStatus(ErrorStatus::GENERAL_FAILURE); + } + const V1_0::Request request12 = convertToV1_2(request); Return<V1_0::ErrorStatus> ret = - mPreparedModelV1_2->execute_1_2(request10, measure, callback); + mPreparedModelV1_2->execute_1_2(request12, measure, callback); if (ret.isDeadObject()) { LOG(ERROR) << "execute_1_2 failure: " << ret.description(); return failDeadObject(); @@ -271,6 +270,12 @@ std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::execu // version 1.0 HAL if (mPreparedModelV1_0 != nullptr) { + const bool compliant = compliantWithV1_0(request); + if (!compliant) { + LOG(ERROR) << "Could not handle execute!"; + return failWithStatus(ErrorStatus::GENERAL_FAILURE); + } + const V1_0::Request request10 = convertToV1_0(request); Return<V1_0::ErrorStatus> ret = mPreparedModelV1_0->execute(request10, callback); if (ret.isDeadObject()) { LOG(ERROR) << "execute failure: " << ret.description(); @@ -324,16 +329,16 @@ std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::execu // version 1.2 HAL if (mPreparedModelV1_2 != nullptr) { - const bool compliant = compliantWithV1_0(request); + const bool compliant = compliantWithV1_2(request); if (!compliant) { LOG(ERROR) << "Could not handle executeSynchronously!"; return kFailure; } - const V1_0::Request request10 = convertToV1_0(request); + const V1_0::Request request12 = convertToV1_2(request); std::tuple<int, std::vector<OutputShape>, Timing> result; Return<void> ret = mPreparedModelV1_2->executeSynchronously( - request10, measure, + request12, measure, [&result](V1_0::ErrorStatus error, const hidl_vec<OutputShape>& outputShapes, const Timing& timing) { result = getExecutionResult(convertToV1_3(error), outputShapes, timing); diff --git a/nn/runtime/test/TestCompliance.cpp b/nn/runtime/test/TestCompliance.cpp index 53bff038b..db5ab4d3e 100644 --- a/nn/runtime/test/TestCompliance.cpp +++ b/nn/runtime/test/TestCompliance.cpp @@ -18,6 +18,7 @@ #include "GeneratedTestUtils.h" #include "HalInterfaces.h" +#include "Memory.h" #include "MemoryUtils.h" #include "ModelBuilder.h" #include "TestNeuralNetworksWrapper.h" @@ -71,8 +72,14 @@ static void testAvailableSinceV1_0(const WrapperModel& wrapperModel) { ASSERT_TRUE(compliantWithV1_0(hidlModel)); } +static void testAvailableSinceV1_2(const Request& request) { + ASSERT_FALSE(compliantWithV1_0(request)); + ASSERT_TRUE(compliantWithV1_2(request)); +} + static void testAvailableSinceV1_3(const Request& request) { ASSERT_FALSE(compliantWithV1_0(request)); + ASSERT_FALSE(compliantWithV1_2(request)); } static const WrapperOperandType kTypeTensorFloat(WrapperType::TENSOR_FLOAT32, {1}); @@ -126,7 +133,7 @@ TEST_F(ComplianceTest, Rank0TensorTemporaryVariable) { testAvailableSinceV1_2(model); } -TEST_F(ComplianceTest, HardwareBuffer) { +TEST_F(ComplianceTest, HardwareBufferModel) { const size_t memorySize = 20; AHardwareBuffer_Desc desc{ .width = memorySize, @@ -157,6 +164,29 @@ TEST_F(ComplianceTest, HardwareBuffer) { AHardwareBuffer_release(buffer); } +TEST_F(ComplianceTest, HardwareBufferRequest) { + const auto [n, ahwb] = MemoryRuntimeAHWB::create(1024); + ASSERT_EQ(n, ANEURALNETWORKS_NO_ERROR); + Request::MemoryPool sharedMemoryPool, ahwbMemoryPool = ahwb->getMemoryPool(); + sharedMemoryPool.hidlMemory(allocateSharedMemory(1024)); + ASSERT_TRUE(sharedMemoryPool.hidlMemory().valid()); + ASSERT_TRUE(ahwbMemoryPool.hidlMemory().valid()); + + // AHardwareBuffer as input. + testAvailableSinceV1_2(Request{ + .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}}, + .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}}, + .pools = {ahwbMemoryPool, sharedMemoryPool}, + }); + + // AHardwareBuffer as output. + testAvailableSinceV1_2(Request{ + .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}}, + .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}}, + .pools = {sharedMemoryPool, ahwbMemoryPool}, + }); +} + TEST_F(ComplianceTest, DeviceMemory) { Request::MemoryPool sharedMemoryPool, deviceMemoryPool; sharedMemoryPool.hidlMemory(allocateSharedMemory(1024)); |