summaryrefslogtreecommitdiff
path: root/nn/runtime/test
diff options
context:
space:
mode:
authorSlava Shklyaev <slavash@google.com>2020-05-11 14:14:17 +0100
committerSlava Shklyaev <slavash@google.com>2020-05-13 10:38:02 +0100
commitc4f49a93f472cf570dde8a623bd40a5ee9bcce00 (patch)
treeab893d219435be6f5ae337e5f8c3457e069f68ea /nn/runtime/test
parent6a8e397336cb494c488951183cc2d591deb8fe13 (diff)
downloadml-c4f49a93f472cf570dde8a623bd40a5ee9bcce00.tar.gz
Fix CPU fallback bug with SIMPLE execution plan
Fix: 155923033 Test: m Change-Id: Ia701c6097695fd452c408d4423998c55d823a52f
Diffstat (limited to 'nn/runtime/test')
-rw-r--r--nn/runtime/test/Android.bp2
-rw-r--r--nn/runtime/test/TestFailingDriver.cpp (renamed from nn/runtime/test/TestControlFlowExecution.cpp)38
2 files changed, 39 insertions, 1 deletions
diff --git a/nn/runtime/test/Android.bp b/nn/runtime/test/Android.bp
index aed0c4e60..65d9de688 100644
--- a/nn/runtime/test/Android.bp
+++ b/nn/runtime/test/Android.bp
@@ -121,9 +121,9 @@ cc_defaults {
// "TestOpenmpSettings.cpp",
"TestCompilationCaching.cpp",
"TestCompliance.cpp",
- "TestControlFlowExecution.cpp",
"TestExecution.cpp",
"TestExtensions.cpp",
+ "TestFailingDriver.cpp",
"TestIntrospectionControl.cpp",
"TestMemoryDomain.cpp",
"TestMemoryInternal.cpp",
diff --git a/nn/runtime/test/TestControlFlowExecution.cpp b/nn/runtime/test/TestFailingDriver.cpp
index d345c6f13..7d41ace20 100644
--- a/nn/runtime/test/TestControlFlowExecution.cpp
+++ b/nn/runtime/test/TestFailingDriver.cpp
@@ -172,5 +172,43 @@ TEST_F(FailingDriverTest, FailAfterInterpretedWhile) {
ASSERT_EQ(fSqrt[1], 5);
}
+// Regression test for b/155923033.
+TEST_F(FailingDriverTest, SimplePlan) {
+ // Model:
+ // output0 = SQRT(input0)
+ //
+ // This results in a SIMPLE execution plan. When FailingTestDriver fails,
+ // partial CPU fallback should complete the execution.
+
+ WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2});
+
+ WrapperModel model;
+ {
+ uint32_t fInput = model.addOperand(&floatType);
+ uint32_t fSqrt = model.addOperand(&floatType);
+ model.addOperation(ANEURALNETWORKS_SQRT, {fInput}, {fSqrt});
+ model.identifyInputsAndOutputs({fInput}, {fSqrt});
+ ASSERT_TRUE(model.isValid());
+ ASSERT_EQ(model.finish(), Result::NO_ERROR);
+ }
+
+ WrapperCompilation compilation(&model);
+ ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
+
+ const CompilationBuilder* compilationBuilder =
+ reinterpret_cast<CompilationBuilder*>(compilation.getHandle());
+ const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan();
+ ASSERT_TRUE(plan.isSimple());
+
+ WrapperExecution execution(&compilation);
+ const float fInput[] = {12 * 12, 5 * 5};
+ float fSqrt[] = {0, 0};
+ ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR);
+ ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR);
+ ASSERT_EQ(execution.compute(), Result::NO_ERROR);
+ ASSERT_EQ(fSqrt[0], 12);
+ ASSERT_EQ(fSqrt[1], 5);
+}
+
} // namespace
} // namespace android::nn