diff options
Diffstat (limited to 'test/opt/dataflow.cpp')
-rw-r--r-- | test/opt/dataflow.cpp | 225 |
1 files changed, 225 insertions, 0 deletions
diff --git a/test/opt/dataflow.cpp b/test/opt/dataflow.cpp new file mode 100644 index 00000000..4742015a --- /dev/null +++ b/test/opt/dataflow.cpp @@ -0,0 +1,225 @@ +// Copyright (c) 2021 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/dataflow.h" + +#include <map> +#include <set> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "opt/function_utils.h" +#include "source/opt/build_module.h" + +namespace spvtools { +namespace opt { +namespace { + +using DataFlowTest = ::testing::Test; + +// Simple analyses for testing: + +// Stores the result IDs of visited instructions in visit order. +struct VisitOrder : public ForwardDataFlowAnalysis { + std::vector<uint32_t> visited_result_ids; + + VisitOrder(IRContext& context, LabelPosition label_position) + : ForwardDataFlowAnalysis(context, label_position) {} + + VisitResult Visit(Instruction* inst) override { + if (inst->HasResultId()) { + visited_result_ids.push_back(inst->result_id()); + } + return DataFlowAnalysis::VisitResult::kResultFixed; + } +}; + +// For each block, stores the set of blocks it can be preceded by. +// For example, with the following CFG: +// V-----------. +// -> 11 -> 12 -> 13 -> 15 +// \-> 14 ---^ +// +// The answer is: +// 11: 11, 12, 13 +// 12: 11, 12, 13 +// 13: 11, 12, 13 +// 14: 11, 12, 13 +// 15: 11, 12, 13, 14 +struct BackwardReachability : public ForwardDataFlowAnalysis { + std::map<uint32_t, std::set<uint32_t>> reachable_from; + + BackwardReachability(IRContext& context) + : ForwardDataFlowAnalysis( + context, ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly) {} + + VisitResult Visit(Instruction* inst) override { + // Conditional branches can be enqueued from labels, so skip them. + if (inst->opcode() != SpvOpLabel) + return DataFlowAnalysis::VisitResult::kResultFixed; + uint32_t id = inst->result_id(); + VisitResult ret = DataFlowAnalysis::VisitResult::kResultFixed; + std::set<uint32_t>& precedents = reachable_from[id]; + for (uint32_t pred : context().cfg()->preds(id)) { + bool pred_inserted = precedents.insert(pred).second; + if (pred_inserted) { + ret = DataFlowAnalysis::VisitResult::kResultChanged; + } + for (uint32_t block : reachable_from[pred]) { + bool inserted = precedents.insert(block).second; + if (inserted) { + ret = DataFlowAnalysis::VisitResult::kResultChanged; + } + } + } + return ret; + } + + void InitializeWorklist(Function* function, + bool is_first_iteration) override { + // Since successor function is exact, only need one pass. + if (is_first_iteration) { + ForwardDataFlowAnalysis::InitializeWorklist(function, true); + } + } +}; + +TEST_F(DataFlowTest, ReversePostOrder) { + // Note: labels and IDs are intentionally out of order. + // + // CFG: (order of branches is from bottom to top) + // V-----------. + // -> 50 -> 40 -> 20 -> 60 -> 70 + // \-> 30 ---^ + + // DFS tree with RPO numbering: + // -> 50[0] -> 40[1] -> 20[2] 60[4] -> 70[5] + // \-> 30[3] ---^ + + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %6 = OpTypeBool + %5 = OpConstantTrue %6 + %2 = OpFunction %3 None %4 + %50 = OpLabel + %51 = OpUndef %6 + %52 = OpUndef %6 + OpBranch %40 + %70 = OpLabel + %69 = OpUndef %6 + OpReturn + %60 = OpLabel + %61 = OpUndef %6 + OpBranchConditional %5 %70 %40 + %30 = OpLabel + %29 = OpUndef %6 + OpBranch %60 + %20 = OpLabel + %21 = OpUndef %6 + OpBranch %60 + %40 = OpLabel + %39 = OpUndef %6 + OpBranchConditional %5 %30 %20 + OpFunctionEnd + )"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(context, nullptr); + + Function* function = spvtest::GetFunction(context->module(), 2); + + std::map<ForwardDataFlowAnalysis::LabelPosition, std::vector<uint32_t>> + expected_order; + expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly] = { + 50, 40, 20, 30, 60, 70, + }; + expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtBeginning] = { + 50, 51, 52, 40, 39, 20, 21, 30, 29, 60, 61, 70, 69, + }; + expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtEnd] = { + 51, 52, 50, 39, 40, 21, 20, 29, 30, 61, 60, 69, 70, + }; + expected_order[ForwardDataFlowAnalysis::LabelPosition::kNoLabels] = { + 51, 52, 39, 21, 29, 61, 69, + }; + + for (const auto& test_case : expected_order) { + VisitOrder analysis(*context, test_case.first); + analysis.Run(function); + EXPECT_EQ(test_case.second, analysis.visited_result_ids); + } +} + +TEST_F(DataFlowTest, BackwardReachability) { + // CFG: + // V-----------. + // -> 11 -> 12 -> 13 -> 15 + // \-> 14 ---^ + + const std::string text = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource GLSL 430 + %3 = OpTypeVoid + %4 = OpTypeFunction %3 + %6 = OpTypeBool + %5 = OpConstantTrue %6 + %2 = OpFunction %3 None %4 + %11 = OpLabel + OpBranch %12 + %12 = OpLabel + OpBranchConditional %5 %14 %13 + %13 = OpLabel + OpBranchConditional %5 %15 %11 + %14 = OpLabel + OpBranch %15 + %15 = OpLabel + OpReturn + OpFunctionEnd + )"; + + std::unique_ptr<IRContext> context = + BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, + SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); + ASSERT_NE(context, nullptr); + + Function* function = spvtest::GetFunction(context->module(), 2); + + BackwardReachability analysis(*context); + analysis.Run(function); + + std::map<uint32_t, std::set<uint32_t>> expected_result; + expected_result[11] = {11, 12, 13}; + expected_result[12] = {11, 12, 13}; + expected_result[13] = {11, 12, 13}; + expected_result[14] = {11, 12, 13}; + expected_result[15] = {11, 12, 13, 14}; + EXPECT_EQ(expected_result, analysis.reachable_from); +} + +} // namespace +} // namespace opt +} // namespace spvtools |