summaryrefslogtreecommitdiff
path: root/nn/runtime/test/TestExecution.cpp
diff options
context:
space:
mode:
authorMiao Wang <miaowang@google.com>2019-01-30 14:26:19 -0800
committerMiao Wang <miaowang@google.com>2019-01-31 20:46:23 +0000
commit3e4c7c2482017f243594b18566035bf5f7911f91 (patch)
treef30efe31b9caaa37fb6378c9c55f5c336032a9f7 /nn/runtime/test/TestExecution.cpp
parent67812c2638641930aa84fd5ada8cace2b728dcd3 (diff)
downloadml-3e4c7c2482017f243594b18566035bf5f7911f91.tar.gz
No CPU fallback will be provided when using introspection API
ANeuralNetworksCompilation_createForDevices. Bug: 120796109 Bug: 120443043 Test: mm Test: NeuralNetworksTest_static Change-Id: I5088caac03cabde63a5447934364172882e6a16c Merged-In: I5088caac03cabde63a5447934364172882e6a16c (cherry picked from commit ad8b3da238794b4465c0bef60a60e1d731e22045)
Diffstat (limited to 'nn/runtime/test/TestExecution.cpp')
-rw-r--r--nn/runtime/test/TestExecution.cpp109
1 files changed, 91 insertions, 18 deletions
diff --git a/nn/runtime/test/TestExecution.cpp b/nn/runtime/test/TestExecution.cpp
index a586455e2..a8576cc50 100644
--- a/nn/runtime/test/TestExecution.cpp
+++ b/nn/runtime/test/TestExecution.cpp
@@ -327,21 +327,69 @@ public:
}
};
-template<class DriverClass>
-class ExecutionTestTemplate :
- public ::testing::TestWithParam<std::tuple<ErrorStatus, Result>> {
-public:
- ExecutionTestTemplate() :
- kName(toString(std::get<0>(GetParam()))),
- kForceErrorStatus(std::get<0>(GetParam())),
- kExpectResult(std::get<1>(GetParam())),
- mModel(makeModel()),
- mCompilation(&mModel, kName, kForceErrorStatus) {}
-
-protected:
+// This class has roughly the same functionality as TestCompilation class.
+// The major difference is that Introspection API is used to select the device.
+template <typename DriverClass>
+class TestIntrospectionCompilation : public WrapperCompilation {
+ public:
+ TestIntrospectionCompilation(const WrapperModel* model, const std::string& deviceName) {
+ std::vector<ANeuralNetworksDevice*> mDevices;
+ uint32_t numDevices = 0;
+ EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
+ EXPECT_GE(numDevices, (uint32_t)1);
+
+ for (uint32_t i = 0; i < numDevices; i++) {
+ ANeuralNetworksDevice* device = nullptr;
+ EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
+ const char* buffer = nullptr;
+ int result = ANeuralNetworksDevice_getName(device, &buffer);
+ if (result == ANEURALNETWORKS_NO_ERROR && deviceName.compare(buffer) == 0) {
+ mDevices.push_back(device);
+ }
+ }
+ // In CPU only mode, DeviceManager::getDrivers() will not be able to
+ // provide the actual device list. We will not be able to find the test
+ // driver with specified deviceName.
+ if (!DeviceManager::get()->getUseCpuOnly()) {
+ EXPECT_EQ(mDevices.size(), (uint32_t)1);
+
+ int result = ANeuralNetworksCompilation_createForDevices(
+ model->getHandle(), mDevices.data(), mDevices.size(), &mCompilation);
+ EXPECT_EQ(result, ANEURALNETWORKS_NO_ERROR);
+ }
+ }
+};
+
+template <class DriverClass>
+class ExecutionTestTemplate
+ : public ::testing::TestWithParam<std::tuple<ErrorStatus, Result, bool>> {
+ public:
+ ExecutionTestTemplate()
+ : kName(toString(std::get<0>(GetParam()))),
+ kForceErrorStatus(std::get<0>(GetParam())),
+ kExpectResult(std::get<1>(GetParam())),
+ kUseIntrospectionAPI(std::get<2>(GetParam())),
+ mModel(makeModel()) {
+ if (kUseIntrospectionAPI) {
+ DeviceManager::get()->forTest_registerDevice(kName.c_str(),
+ new DriverClass(kName, kForceErrorStatus));
+ mCompilation = TestIntrospectionCompilation<DriverClass>(&mModel, kName);
+ } else {
+ mCompilation = TestCompilation<DriverClass>(&mModel, kName, kForceErrorStatus);
+ }
+ }
+
+ protected:
// Unit test method
void TestWait();
+ virtual void TearDown() {
+ // Reinitialize the device list since Introspection API path altered it.
+ if (kUseIntrospectionAPI) {
+ DeviceManager::get()->forTest_reInitializeDeviceList();
+ }
+ }
+
const std::string kName;
// Allow dummying up the error status for execution. If
@@ -355,8 +403,11 @@ protected:
// equivalent of kForceErrorStatus.)
const Result kExpectResult;
+ // Whether mCompilation is created via Introspection API or not.
+ const bool kUseIntrospectionAPI;
+
WrapperModel mModel;
- TestCompilation<DriverClass> mCompilation;
+ WrapperCompilation mCompilation;
void setInputOutput(WrapperExecution* execution) {
mInputBuffer = kInputBuffer;
@@ -388,6 +439,11 @@ protected:
template<class DriverClass> void ExecutionTestTemplate<DriverClass>::TestWait() {
SCOPED_TRACE(kName);
+ // Skip Introspection API tests when CPU only flag is forced on.
+ if (kUseIntrospectionAPI && DeviceManager::get()->getUseCpuOnly()) {
+ GTEST_SKIP();
+ }
+
ASSERT_EQ(mCompilation.finish(), Result::NO_ERROR);
{
@@ -414,11 +470,15 @@ template<class DriverClass> void ExecutionTestTemplate<DriverClass>::TestWait()
}
auto kTestValues = ::testing::Values(
- std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR),
- std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE),
- std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED),
- std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE),
- std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA));
+ std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ false),
+ std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE,
+ /* kUseIntrospectionAPI */ false),
+ std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED,
+ /* kUseIntrospectionAPI */ false),
+ std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE,
+ /* kUseIntrospectionAPI */ false),
+ std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA,
+ /* kUseIntrospectionAPI */ false));
class ExecutionTest12 : public ExecutionTestTemplate<TestDriver12> {};
TEST_P(ExecutionTest12, Wait) {
@@ -438,5 +498,18 @@ TEST_P(ExecutionTest10, Wait) {
}
INSTANTIATE_TEST_CASE_P(Flavor, ExecutionTest10, kTestValues);
+auto kIntrospectionTestValues = ::testing::Values(
+ std::make_tuple(ErrorStatus::NONE, Result::NO_ERROR, /* kUseIntrospectionAPI */ true),
+ std::make_tuple(ErrorStatus::DEVICE_UNAVAILABLE, Result::UNAVAILABLE_DEVICE,
+ /* kUseIntrospectionAPI */ true),
+ std::make_tuple(ErrorStatus::GENERAL_FAILURE, Result::OP_FAILED,
+ /* kUseIntrospectionAPI */ true),
+ std::make_tuple(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, Result::OUTPUT_INSUFFICIENT_SIZE,
+ /* kUseIntrospectionAPI */ true),
+ std::make_tuple(ErrorStatus::INVALID_ARGUMENT, Result::BAD_DATA,
+ /* kUseIntrospectionAPI */ true));
+
+INSTANTIATE_TEST_CASE_P(IntrospectionFlavor, ExecutionTest12, kIntrospectionTestValues);
+
} // namespace
} // namespace android