diff options
-rw-r--r-- | nn/runtime/ExecutionPlan.cpp | 1 | ||||
-rw-r--r-- | nn/runtime/test/Android.bp | 2 | ||||
-rw-r--r-- | nn/runtime/test/TestFailingDriver.cpp (renamed from nn/runtime/test/TestControlFlowExecution.cpp) | 38 |
3 files changed, 40 insertions, 1 deletions
diff --git a/nn/runtime/ExecutionPlan.cpp b/nn/runtime/ExecutionPlan.cpp index 0557f95d4..5e618825f 100644 --- a/nn/runtime/ExecutionPlan.cpp +++ b/nn/runtime/ExecutionPlan.cpp @@ -1108,6 +1108,7 @@ int ExecutionPlan::next(std::shared_ptr<Controller> controller, if (burstController != nullptr && controller->mBurstBuilder != nullptr) { *burstController = controller->mBurstBuilder->getControllerAt(0); } + controller->mFallbackNextStepIndex = 0; controller->mNextStepIndex = 1; return ANEURALNETWORKS_NO_ERROR; } diff --git a/nn/runtime/test/Android.bp b/nn/runtime/test/Android.bp index aed0c4e60..65d9de688 100644 --- a/nn/runtime/test/Android.bp +++ b/nn/runtime/test/Android.bp @@ -121,9 +121,9 @@ cc_defaults { // "TestOpenmpSettings.cpp", "TestCompilationCaching.cpp", "TestCompliance.cpp", - "TestControlFlowExecution.cpp", "TestExecution.cpp", "TestExtensions.cpp", + "TestFailingDriver.cpp", "TestIntrospectionControl.cpp", "TestMemoryDomain.cpp", "TestMemoryInternal.cpp", diff --git a/nn/runtime/test/TestControlFlowExecution.cpp b/nn/runtime/test/TestFailingDriver.cpp index d345c6f13..7d41ace20 100644 --- a/nn/runtime/test/TestControlFlowExecution.cpp +++ b/nn/runtime/test/TestFailingDriver.cpp @@ -172,5 +172,43 @@ TEST_F(FailingDriverTest, FailAfterInterpretedWhile) { ASSERT_EQ(fSqrt[1], 5); } +// Regression test for b/155923033. +TEST_F(FailingDriverTest, SimplePlan) { + // Model: + // output0 = SQRT(input0) + // + // This results in a SIMPLE execution plan. When FailingTestDriver fails, + // partial CPU fallback should complete the execution. + + WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2}); + + WrapperModel model; + { + uint32_t fInput = model.addOperand(&floatType); + uint32_t fSqrt = model.addOperand(&floatType); + model.addOperation(ANEURALNETWORKS_SQRT, {fInput}, {fSqrt}); + model.identifyInputsAndOutputs({fInput}, {fSqrt}); + ASSERT_TRUE(model.isValid()); + ASSERT_EQ(model.finish(), Result::NO_ERROR); + } + + WrapperCompilation compilation(&model); + ASSERT_EQ(compilation.finish(), Result::NO_ERROR); + + const CompilationBuilder* compilationBuilder = + reinterpret_cast<CompilationBuilder*>(compilation.getHandle()); + const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan(); + ASSERT_TRUE(plan.isSimple()); + + WrapperExecution execution(&compilation); + const float fInput[] = {12 * 12, 5 * 5}; + float fSqrt[] = {0, 0}; + ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR); + ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR); + ASSERT_EQ(execution.compute(), Result::NO_ERROR); + ASSERT_EQ(fSqrt[0], 12); + ASSERT_EQ(fSqrt[1], 5); +} + } // namespace } // namespace android::nn |