diff options
author | Miao Wang <miaowang@google.com> | 2019-01-30 14:26:19 -0800 |
---|---|---|
committer | Miao Wang <miaowang@google.com> | 2019-01-31 00:37:07 -0800 |
commit | ad8b3da238794b4465c0bef60a60e1d731e22045 (patch) | |
tree | 35d2c436d88a70ca5423e32fddf0397201387732 /nn/runtime/test/TestExecution.cpp | |
parent | fddec3e0bd84bfa47339959848efc322b12dd358 (diff) | |
download | ml-ad8b3da238794b4465c0bef60a60e1d731e22045.tar.gz |
No CPU fallback will be provided when using introspection API
ANeuralNetworksCompilation_createForDevices.
Bug: 120796109
Bug: 120443043
Test: mm
Test: NeuralNetworksTest_static
Change-Id: I5088caac03cabde63a5447934364172882e6a16c
Diffstat (limited to 'nn/runtime/test/TestExecution.cpp')
-rw-r--r-- | nn/runtime/test/TestExecution.cpp | 109 |
1 files changed, 91 insertions, 18 deletions
diff --git a/nn/runtime/test/TestExecution.cpp b/nn/runtime/test/TestExecution.cpp index 92e5e585b..c4f15751c 100644 --- a/nn/runtime/test/TestExecution.cpp +++ b/nn/runtime/test/TestExecution.cpp @@ -335,21 +335,69 @@ public: } }; -template<class DriverClass> -class ExecutionTestTemplate : - public ::testing::TestWithParam<std::tuple<ErrorStatus, Result>> { -public: - ExecutionTestTemplate() : - kName(toString(std::get<0>(GetParam()))), - kForceErrorStatus(std::get<0>(GetParam())), - kExpectResult(std::get<1>(GetParam())), - mModel(makeModel()), - mCompilation(&mModel, kName, kForceErrorStatus) {} - -protected: +// This class has roughly the same functionality as TestCompilation class. +// The major difference is that Introspection API is used to select the device. +template <typename DriverClass> +class TestIntrospectionCompilation : public WrapperCompilation { + public: + TestIntrospectionCompilation(const WrapperModel* model, const std::string& deviceName) { + std::vector<ANeuralNetworksDevice*> mDevices; + uint32_t numDevices = 0; + EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR); + EXPECT_GE(numDevices, (uint32_t)1); + + for (uint32_t i = 0; i < numDevices; i++) { + ANeuralNetworksDevice* device = nullptr; + EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR); + const char* buffer = nullptr; + int result = ANeuralNetworksDevice_getName(device, &buffer); + if (result == ANEURALNETWORKS_NO_ERROR && deviceName.compare(buffer) == 0) { + mDevices.push_back(device); + } + } + // In CPU only mode, DeviceManager::getDrivers() will not be able to + // provide the actual device list. We will not be able to find the test + // driver with specified deviceName. + if (!DeviceManager::get()->getUseCpuOnly()) { + EXPECT_EQ(mDevices.size(), (uint32_t)1); + + int result = ANeuralNetworksCompilation_createForDevices( + model->getHandle(), mDevices.data(), mDevices.size(), &mCompilation); + EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR); + } + } +}; + +template <class DriverClass> +class ExecutionTestTemplate + : public ::testing::TestWithParam<std::tuple<ErrorStatus, Result, bool>> { + public: + ExecutionTestTemplate() + : kName(toString(std::get<0>(GetParam()))), + kForceErrorStatus(std::get<0>(GetParam())), + kExpectResult(std::get<1>(GetParam())), + kUseIntrospectionAPI(std::get<2>(GetParam())), + mModel(makeModel()) { + if (kUseIntrospectionAPI) { + DeviceManager::get()->forTest_registerDevice(kName.c_str(), + new DriverClass(kName, kForceErrorStatus)); + mCompilation = TestIntrospectionCompilation<DriverClass>(&mModel, kName); + } else { + mCompilation = TestCompilation<DriverClass>(&mModel, kName, kForceErrorStatus); + } + } + + protected: // Unit test method void TestWait(); + virtual void TearDown() { + // Reinitialize the device list since Introspection API path altered it. + if (kUseIntrospectionAPI) { + DeviceManager::get()->forTest_reInitializeDeviceList(); + } + } + const std::string kName; // Allow dummying up the error status for execution. If @@ -363,8 +411,11 @@ protected: // equivalent of kForceErrorStatus.) const Result kExpectResult; + // Whether mCompilation is created via Introspection API or not. + const bool kUseIntrospectionAPI; + WrapperModel mModel; - TestCompilation<DriverClass> mCompilation; + WrapperCompilation mCompilation; void setInputOutput(WrapperExecution* execution) { mInputBuffer = kInputBuffer; @@ -397,6 +448,11 @@ protected: template<class DriverClass> void ExecutionTestTemplate<DriverClass>::TestWait() { SCOPED_TRACE(kName); + // Skip Introspection API tests when CPU only flag is forced on. + if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) { + GTEST_SKIP(); + } + ASSERT_EQ(mCompilation.finish(), Result::NO_ERROR); { @@ -446,11 +502,15 @@ template<class DriverClass> void ExecutionTestTemplate<DriverClass>::TestWait() } auto kTestValues = ::testing::Values( - std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR), - std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE), - std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED), - std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE), - std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA)); + std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ false), + std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE, + /* kUseIntrospectionAPI */ false), + std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED, + /* kUseIntrospectionAPI */ false), + std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE, + /* kUseIntrospectionAPI */ false), + std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA, + /* kUseIntrospectionAPI */ false)); class ExecutionTest12 : public ExecutionTestTemplate<TestDriver12> {}; TEST_P(ExecutionTest12, Wait) { @@ -472,5 +532,18 @@ TEST_P(ExecutionTest10, Wait) { } INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest10, kTestValues); +auto kIntrospectionTestValues = ::testing::Values( + std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ true), + std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE, + /* kUseIntrospectionAPI */ true), + std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED, + /* kUseIntrospectionAPI */ true), + std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE, + /* kUseIntrospectionAPI */ true), + std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA, + /* kUseIntrospectionAPI */ true)); + +INSTANTIATE_TEST_CASE_P(IntrospectionFlavor, ExecutionTest12, kIntrospectionTestValues); + } // namespace } // namespace android |