diff options
author | Slava Shklyaev <slavash@google.com> | 2020-05-11 14:14:17 +0100 |
---|---|---|
committer | Slava Shklyaev <slavash@google.com> | 2020-05-13 10:38:02 +0100 |
commit | c4f49a93f472cf570dde8a623bd40a5ee9bcce00 (patch) | |
tree | ab893d219435be6f5ae337e5f8c3457e069f68ea /nn/runtime/test/TestFailingDriver.cpp | |
parent | 6a8e397336cb494c488951183cc2d591deb8fe13 (diff) | |
download | ml-c4f49a93f472cf570dde8a623bd40a5ee9bcce00.tar.gz |
Fix CPU fallback bug with SIMPLE execution plan
Fix: 155923033
Test: m
Change-Id: Ia701c6097695fd452c408d4423998c55d823a52f
Diffstat (limited to 'nn/runtime/test/TestFailingDriver.cpp')
-rw-r--r-- | nn/runtime/test/TestFailingDriver.cpp | 214 |
1 files changed, 214 insertions, 0 deletions
diff --git a/nn/runtime/test/TestFailingDriver.cpp b/nn/runtime/test/TestFailingDriver.cpp new file mode 100644 index 000000000..7d41ace20 --- /dev/null +++ b/nn/runtime/test/TestFailingDriver.cpp @@ -0,0 +1,214 @@ +/* + * Copyright (C) 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +#include "CompilationBuilder.h" +#include "ExecutionPlan.h" +#include "Manager.h" +#include "SampleDriverPartial.h" +#include "TestNeuralNetworksWrapper.h" + +namespace android::nn { +namespace { + +using namespace hal; +using sample_driver::SampleDriverPartial; +using Result = test_wrapper::Result; +using WrapperOperandType = test_wrapper::OperandType; +using WrapperCompilation = test_wrapper::Compilation; +using WrapperExecution = test_wrapper::Execution; +using WrapperType = test_wrapper::Type; +using WrapperModel = test_wrapper::Model; + +class EmptyOperationResolver : public IOperationResolver { + public: + const OperationRegistration* findOperation(OperationType) const override { return nullptr; } +}; + +const char* kTestDriverName = "nnapi-test-sqrt-failing"; + +// A driver that only supports SQRT and fails during execution. +class FailingTestDriver : public SampleDriverPartial { + public: + // EmptyOperationResolver causes execution to fail. + FailingTestDriver() : SampleDriverPartial(kTestDriverName, &mEmptyOperationResolver) {} + + Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override { + cb(V1_3::ErrorStatus::NONE, + {.operandPerformance = {{.type = OperandType::TENSOR_FLOAT32, + .info = {.execTime = 0.1, // Faster than CPU. + .powerUsage = 0.1}}}}); + return Void(); + } + + private: + std::vector<bool> getSupportedOperationsImpl(const Model& model) const override { + std::vector<bool> supported(model.main.operations.size()); + std::transform( + model.main.operations.begin(), model.main.operations.end(), supported.begin(), + [](const Operation& operation) { return operation.type == OperationType::SQRT; }); + return supported; + } + + const EmptyOperationResolver mEmptyOperationResolver; +}; + +class FailingDriverTest : public ::testing::Test { + virtual void SetUp() { + DeviceManager* deviceManager = DeviceManager::get(); + if (deviceManager->getUseCpuOnly() || + !DeviceManager::partitioningAllowsFallback(deviceManager->getPartitioning())) { + GTEST_SKIP(); + } + mTestDevice = + DeviceManager::forTest_makeDriverDevice(kTestDriverName, new FailingTestDriver()); + deviceManager->forTest_setDevices({ + mTestDevice, + DeviceManager::getCpuDevice(), + }); + } + + virtual void TearDown() { DeviceManager::get()->forTest_reInitializeDeviceList(); } + + protected: + std::shared_ptr<Device> mTestDevice; +}; + +// Regression test for b/152623150. +TEST_F(FailingDriverTest, FailAfterInterpretedWhile) { + // Model: + // f = input0 + // b = input1 + // while CAST(b): # Identity cast. + // f = CAST(f) + // # FailingTestDriver fails here. When partial CPU fallback happens, + // # it should not loop forever. + // output0 = SQRT(f) + + WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2}); + WrapperOperandType boolType(WrapperType::TENSOR_BOOL8, {1}); + + WrapperModel conditionModel; + { + uint32_t f = conditionModel.addOperand(&floatType); + uint32_t b = conditionModel.addOperand(&boolType); + uint32_t out = conditionModel.addOperand(&boolType); + conditionModel.addOperation(ANEURALNETWORKS_CAST, {b}, {out}); + conditionModel.identifyInputsAndOutputs({f, b}, {out}); + ASSERT_EQ(conditionModel.finish(), Result::NO_ERROR); + ASSERT_TRUE(conditionModel.isValid()); + } + + WrapperModel bodyModel; + { + uint32_t f = bodyModel.addOperand(&floatType); + uint32_t b = bodyModel.addOperand(&boolType); + uint32_t out = bodyModel.addOperand(&floatType); + bodyModel.addOperation(ANEURALNETWORKS_CAST, {f}, {out}); + bodyModel.identifyInputsAndOutputs({f, b}, {out}); + ASSERT_EQ(bodyModel.finish(), Result::NO_ERROR); + ASSERT_TRUE(bodyModel.isValid()); + } + + WrapperModel model; + { + uint32_t fInput = model.addOperand(&floatType); + uint32_t bInput = model.addOperand(&boolType); + uint32_t fTmp = model.addOperand(&floatType); + uint32_t fSqrt = model.addOperand(&floatType); + uint32_t cond = model.addModelOperand(&conditionModel); + uint32_t body = model.addModelOperand(&bodyModel); + model.addOperation(ANEURALNETWORKS_WHILE, {cond, body, fInput, bInput}, {fTmp}); + model.addOperation(ANEURALNETWORKS_SQRT, {fTmp}, {fSqrt}); + model.identifyInputsAndOutputs({fInput, bInput}, {fSqrt}); + ASSERT_TRUE(model.isValid()); + ASSERT_EQ(model.finish(), Result::NO_ERROR); + } + + WrapperCompilation compilation(&model); + ASSERT_EQ(compilation.finish(), Result::NO_ERROR); + + const CompilationBuilder* compilationBuilder = + reinterpret_cast<CompilationBuilder*>(compilation.getHandle()); + const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan(); + const std::vector<std::shared_ptr<LogicalStep>>& steps = plan.forTest_compoundGetSteps(); + ASSERT_EQ(steps.size(), 6u); + ASSERT_TRUE(steps[0]->isWhile()); + ASSERT_TRUE(steps[1]->isExecution()); + ASSERT_EQ(steps[1]->executionStep()->getDevice(), DeviceManager::getCpuDevice()); + ASSERT_TRUE(steps[2]->isGoto()); + ASSERT_TRUE(steps[3]->isExecution()); + ASSERT_EQ(steps[3]->executionStep()->getDevice(), DeviceManager::getCpuDevice()); + ASSERT_TRUE(steps[4]->isGoto()); + ASSERT_TRUE(steps[5]->isExecution()); + ASSERT_EQ(steps[5]->executionStep()->getDevice(), mTestDevice); + + WrapperExecution execution(&compilation); + const float fInput[] = {12 * 12, 5 * 5}; + const bool8 bInput = false; + float fSqrt[] = {0, 0}; + ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR); + ASSERT_EQ(execution.setInput(1, &bInput), Result::NO_ERROR); + ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR); + ASSERT_EQ(execution.compute(), Result::NO_ERROR); + ASSERT_EQ(fSqrt[0], 12); + ASSERT_EQ(fSqrt[1], 5); +} + +// Regression test for b/155923033. +TEST_F(FailingDriverTest, SimplePlan) { + // Model: + // output0 = SQRT(input0) + // + // This results in a SIMPLE execution plan. When FailingTestDriver fails, + // partial CPU fallback should complete the execution. + + WrapperOperandType floatType(WrapperType::TENSOR_FLOAT32, {2}); + + WrapperModel model; + { + uint32_t fInput = model.addOperand(&floatType); + uint32_t fSqrt = model.addOperand(&floatType); + model.addOperation(ANEURALNETWORKS_SQRT, {fInput}, {fSqrt}); + model.identifyInputsAndOutputs({fInput}, {fSqrt}); + ASSERT_TRUE(model.isValid()); + ASSERT_EQ(model.finish(), Result::NO_ERROR); + } + + WrapperCompilation compilation(&model); + ASSERT_EQ(compilation.finish(), Result::NO_ERROR); + + const CompilationBuilder* compilationBuilder = + reinterpret_cast<CompilationBuilder*>(compilation.getHandle()); + const ExecutionPlan& plan = compilationBuilder->forTest_getExecutionPlan(); + ASSERT_TRUE(plan.isSimple()); + + WrapperExecution execution(&compilation); + const float fInput[] = {12 * 12, 5 * 5}; + float fSqrt[] = {0, 0}; + ASSERT_EQ(execution.setInput(0, &fInput), Result::NO_ERROR); + ASSERT_EQ(execution.setOutput(0, &fSqrt), Result::NO_ERROR); + ASSERT_EQ(execution.compute(), Result::NO_ERROR); + ASSERT_EQ(fSqrt[0], 12); + ASSERT_EQ(fSqrt[1], 5); +} + +} // namespace +} // namespace android::nn |