aboutsummaryrefslogtreecommitdiff
path: root/source/opt/dead_branch_elim_pass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/opt/dead_branch_elim_pass.cpp')
-rw-r--r--source/opt/dead_branch_elim_pass.cpp418
1 file changed, 418 insertions, 0 deletions
diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp
new file mode 100644
index 00000000..0c57307b
--- /dev/null
+++ b/source/opt/dead_branch_elim_pass.cpp
@@ -0,0 +1,418 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dead_branch_elim_pass.h"
+
+#include "cfa.h"
+#include "iterator.h"
+
+namespace spvtools {
+namespace opt {
+
namespace {

// In-operand indices for the instructions this pass inspects. These index
// into an instruction's *input* operands (result id and result type are
// excluded from the count).
const uint32_t kEntryPointFunctionIdInIdx = 1;        // OpEntryPoint: function id
const uint32_t kBranchCondConditionalIdInIdx = 0;     // OpBranchConditional: condition id
const uint32_t kBranchCondTrueLabIdInIdx = 1;         // OpBranchConditional: true label
const uint32_t kBranchCondFalseLabIdInIdx = 2;        // OpBranchConditional: false label
const uint32_t kSelectionMergeMergeBlockIdInIdx = 0;  // OpSelectionMerge: merge block
const uint32_t kPhiVal0IdInIdx = 0;                   // OpPhi: first value
const uint32_t kPhiLab0IdInIdx = 1;                   // OpPhi: first parent label
const uint32_t kPhiVal1IdInIdx = 2;                   // OpPhi: second value
const uint32_t kLoopMergeMergeBlockIdInIdx = 0;       // OpLoopMerge: merge block
const uint32_t kLoopMergeContinueBlockIdInIdx = 1;    // OpLoopMerge: continue target

} // anonymous namespace
+
+uint32_t DeadBranchElimPass::MergeBlockIdIfAny(
+ const ir::BasicBlock& blk, uint32_t* cbid) const {
+ auto merge_ii = blk.cend();
+ --merge_ii;
+ uint32_t mbid = 0;
+ *cbid = 0;
+ if (merge_ii != blk.cbegin()) {
+ --merge_ii;
+ if (merge_ii->opcode() == SpvOpLoopMerge) {
+ mbid = merge_ii->GetSingleWordInOperand(kLoopMergeMergeBlockIdInIdx);
+ *cbid = merge_ii->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx);
+ }
+ else if (merge_ii->opcode() == SpvOpSelectionMerge) {
+ mbid = merge_ii->GetSingleWordInOperand(
+ kSelectionMergeMergeBlockIdInIdx);
+ }
+ }
+ return mbid;
+}
+
+void DeadBranchElimPass::ComputeStructuredSuccessors(ir::Function* func) {
+ // If header, make merge block first successor. If a loop header, make
+ // the second successor the continue target.
+ for (auto& blk : *func) {
+ uint32_t cbid;
+ uint32_t mbid = MergeBlockIdIfAny(blk, &cbid);
+ if (mbid != 0) {
+ block2structured_succs_[&blk].push_back(id2block_[mbid]);
+ if (cbid != 0)
+ block2structured_succs_[&blk].push_back(id2block_[cbid]);
+ }
+ // add true successors
+ blk.ForEachSuccessorLabel([&blk, this](uint32_t sbid) {
+ block2structured_succs_[&blk].push_back(id2block_[sbid]);
+ });
+ }
+}
+
+void DeadBranchElimPass::ComputeStructuredOrder(
+ ir::Function* func, std::list<ir::BasicBlock*>* order) {
+ // Compute structured successors and do DFS
+ ComputeStructuredSuccessors(func);
+ auto ignore_block = [](cbb_ptr) {};
+ auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
+ auto get_structured_successors = [this](const ir::BasicBlock* block) {
+ return &(block2structured_succs_[block]); };
+ // TODO(greg-lunarg): Get rid of const_cast by making moving const
+ // out of the cfa.h prototypes and into the invoking code.
+ auto post_order = [&](cbb_ptr b) {
+ order->push_front(const_cast<ir::BasicBlock*>(b)); };
+
+ spvtools::CFA<ir::BasicBlock>::DepthFirstTraversal(
+ &*func->begin(), get_structured_successors, ignore_block, post_order,
+ ignore_edge);
+}
+
+void DeadBranchElimPass::GetConstCondition(
+ uint32_t condId, bool* condVal, bool* condIsConst) {
+ ir::Instruction* cInst = def_use_mgr_->GetDef(condId);
+ switch (cInst->opcode()) {
+ case SpvOpConstantFalse: {
+ *condVal = false;
+ *condIsConst = true;
+ } break;
+ case SpvOpConstantTrue: {
+ *condVal = true;
+ *condIsConst = true;
+ } break;
+ case SpvOpLogicalNot: {
+ bool negVal;
+ (void)GetConstCondition(cInst->GetSingleWordInOperand(0),
+ &negVal, condIsConst);
+ if (*condIsConst)
+ *condVal = !negVal;
+ } break;
+ default: {
+ *condIsConst = false;
+ } break;
+ }
+}
+
+void DeadBranchElimPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) {
+ std::unique_ptr<ir::Instruction> newBranch(
+ new ir::Instruction(SpvOpBranch, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}}));
+ def_use_mgr_->AnalyzeInstDefUse(&*newBranch);
+ bp->AddInstruction(std::move(newBranch));
+}
+
+void DeadBranchElimPass::AddSelectionMerge(uint32_t labelId,
+ ir::BasicBlock* bp) {
+ std::unique_ptr<ir::Instruction> newMerge(
+ new ir::Instruction(SpvOpSelectionMerge, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {0}}}));
+ def_use_mgr_->AnalyzeInstDefUse(&*newMerge);
+ bp->AddInstruction(std::move(newMerge));
+}
+
+void DeadBranchElimPass::AddBranchConditional(uint32_t condId,
+ uint32_t trueLabId, uint32_t falseLabId, ir::BasicBlock* bp) {
+ std::unique_ptr<ir::Instruction> newBranchCond(
+ new ir::Instruction(SpvOpBranchConditional, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {condId}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {trueLabId}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {falseLabId}}}));
+ def_use_mgr_->AnalyzeInstDefUse(&*newBranchCond);
+ bp->AddInstruction(std::move(newBranchCond));
+}
+
+void DeadBranchElimPass::KillNamesAndDecorates(uint32_t id) {
+ // TODO(greg-lunarg): Remove id from any OpGroupDecorate and
+ // kill if no other operands.
+ if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+ return;
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses == nullptr)
+ return;
+ std::list<ir::Instruction*> killList;
+ for (auto u : *uses) {
+ const SpvOp op = u.inst->opcode();
+ if (op == SpvOpName || IsDecorate(op))
+ killList.push_back(u.inst);
+ }
+ for (auto kip : killList)
+ def_use_mgr_->KillInst(kip);
+}
+
+void DeadBranchElimPass::KillNamesAndDecorates(ir::Instruction* inst) {
+ const uint32_t rId = inst->result_id();
+ if (rId == 0)
+ return;
+ KillNamesAndDecorates(rId);
+}
+
+void DeadBranchElimPass::KillAllInsts(ir::BasicBlock* bp) {
+ bp->ForEachInst([this](ir::Instruction* ip) {
+ KillNamesAndDecorates(ip);
+ def_use_mgr_->KillInst(ip);
+ });
+}
+
+bool DeadBranchElimPass::GetConstConditionalSelectionBranch(ir::BasicBlock* bp,
+ ir::Instruction** branchInst, ir::Instruction** mergeInst,
+ uint32_t *condId, bool *condVal) {
+ auto ii = bp->end();
+ --ii;
+ *branchInst = &*ii;
+ if ((*branchInst)->opcode() != SpvOpBranchConditional)
+ return false;
+ if (ii == bp->begin())
+ return false;
+ --ii;
+ *mergeInst = &*ii;
+ if ((*mergeInst)->opcode() != SpvOpSelectionMerge)
+ return false;
+ bool condIsConst;
+ *condId = (*branchInst)->GetSingleWordInOperand(
+ kBranchCondConditionalIdInIdx);
+ (void) GetConstCondition(*condId, condVal, &condIsConst);
+ return condIsConst;
+}
+
+bool DeadBranchElimPass::HasNonPhiRef(uint32_t labelId) {
+ analysis::UseList* uses = def_use_mgr_->GetUses(labelId);
+ if (uses == nullptr)
+ return false;
+ for (auto u : *uses)
+ if (u.inst->opcode() != SpvOpPhi)
+ return true;
+ return false;
+}
+
bool DeadBranchElimPass::EliminateDeadBranches(ir::Function* func) {
  // For each block ending in OpBranchConditional on a constant condition
  // (and preceded by OpSelectionMerge), rewrite the branch to jump directly
  // to the live target, mark the dead arm's blocks for deletion, and patch
  // the merge block's phis. Returns true if anything was modified.
  // Traverse blocks in structured order
  std::list<ir::BasicBlock*> structuredOrder;
  ComputeStructuredOrder(func, &structuredOrder);
  std::unordered_set<ir::BasicBlock*> elimBlocks;
  bool modified = false;
  for (auto bi = structuredOrder.begin(); bi != structuredOrder.end(); ++bi) {
    // Skip blocks that are already in the elimination set
    if (elimBlocks.find(*bi) != elimBlocks.end())
      continue;
    // Skip blocks that don't have constant conditional branch preceded
    // by OpSelectionMerge
    ir::Instruction* br;
    ir::Instruction* mergeInst;
    uint32_t condId;
    bool condVal;
    if (!GetConstConditionalSelectionBranch(*bi, &br, &mergeInst, &condId,
        &condVal))
      continue;

    // Replace conditional branch with unconditional branch to the arm
    // selected by the constant condition; the old branch and the
    // OpSelectionMerge are killed.
    const uint32_t trueLabId =
        br->GetSingleWordInOperand(kBranchCondTrueLabIdInIdx);
    const uint32_t falseLabId =
        br->GetSingleWordInOperand(kBranchCondFalseLabIdInIdx);
    const uint32_t mergeLabId =
        mergeInst->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx);
    const uint32_t liveLabId = condVal == true ? trueLabId : falseLabId;
    const uint32_t deadLabId = condVal == true ? falseLabId : trueLabId;
    AddBranch(liveLabId, *bi);
    def_use_mgr_->KillInst(br);
    def_use_mgr_->KillInst(mergeInst);

    // Iterate to merge block adding dead blocks to elimination set.
    // Structured order places the construct's interior blocks between the
    // header and its merge block, so this walk visits exactly those blocks.
    auto dbi = bi;
    ++dbi;
    uint32_t dLabId = (*dbi)->id();
    while (dLabId != mergeLabId) {
      if (!HasNonPhiRef(dLabId)) {
        // Kill use/def for all instructions and mark block for elimination
        KillAllInsts(*dbi);
        elimBlocks.insert(*dbi);
      }
      ++dbi;
      dLabId = (*dbi)->id();
    }

    // Process phi instructions in merge block.
    // elimBlocks are now blocks which cannot precede merge block. Also,
    // if eliminated branch is to merge label, remember the conditional block
    // also cannot precede merge block.
    uint32_t deadCondLabId = 0;
    if (deadLabId == mergeLabId)
      deadCondLabId = (*bi)->id();
    // Each phi here has two (value, label) pairs; keep the value whose
    // predecessor survives and replace all uses of the phi with it.
    (*dbi)->ForEachPhiInst([&elimBlocks, &deadCondLabId, this](
        ir::Instruction* phiInst) {
      const uint32_t phiLabId0 =
          phiInst->GetSingleWordInOperand(kPhiLab0IdInIdx);
      const bool useFirst =
          elimBlocks.find(id2block_[phiLabId0]) == elimBlocks.end() &&
          phiLabId0 != deadCondLabId;
      const uint32_t phiValIdx =
          useFirst ? kPhiVal0IdInIdx : kPhiVal1IdInIdx;
      const uint32_t replId = phiInst->GetSingleWordInOperand(phiValIdx);
      const uint32_t phiId = phiInst->result_id();
      KillNamesAndDecorates(phiId);
      (void)def_use_mgr_->ReplaceAllUsesWith(phiId, replId);
      def_use_mgr_->KillInst(phiInst);
    });

    // If merge block has no predecessors, replace the new branch with
    // a MergeSelection/BranchCondition using the original constant condition
    // and the mergeblock as the false branch. This is done so the merge block
    // is not orphaned, which could cause invalid control flow in certain case.
    // TODO(greg-lunarg): Do this only in cases where invalid code is caused.
    if (!HasNonPhiRef(mergeLabId)) {
      auto eii = (*bi)->end();
      --eii;
      ir::Instruction* nbr = &*eii;
      AddSelectionMerge(mergeLabId, *bi);
      if (condVal == true)
        AddBranchConditional(condId, liveLabId, mergeLabId, *bi);
      else
        AddBranchConditional(condId, mergeLabId, liveLabId, *bi);
      def_use_mgr_->KillInst(nbr);
    }
    modified = true;
  }

  // Erase dead blocks
  for (auto ebi = func->begin(); ebi != func->end(); )
    if (elimBlocks.find(&*ebi) != elimBlocks.end())
      ebi = ebi.Erase();
    else
      ++ebi;
  return modified;
}
+
+void DeadBranchElimPass::Initialize(ir::Module* module) {
+
+ module_ = module;
+
+ // Initialize function and block maps
+ id2function_.clear();
+ id2block_.clear();
+ block2structured_succs_.clear();
+ for (auto& fn : *module_) {
+ // Initialize function and block maps.
+ id2function_[fn.result_id()] = &fn;
+ for (auto& blk : fn) {
+ id2block_[blk.id()] = &blk;
+ }
+ }
+
+ // TODO(greg-lunarg): Reuse def/use from previous passes
+ def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+ // Initialize extension whitelist
+ InitExtensions();
+};
+
+void DeadBranchElimPass::FindNamedOrDecoratedIds() {
+ for (auto& di : module_->debugs())
+ if (di.opcode() == SpvOpName)
+ named_or_decorated_ids_.insert(di.GetSingleWordInOperand(0));
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpDecorate || ai.opcode() == SpvOpDecorateId)
+ named_or_decorated_ids_.insert(ai.GetSingleWordInOperand(0));
+}
+
+bool DeadBranchElimPass::AllExtensionsSupported() const {
+ // If any extension not in whitelist, return false
+ for (auto& ei : module_->extensions()) {
+ const char* extName = reinterpret_cast<const char*>(
+ &ei.GetInOperand(0).words[0]);
+ if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
+ return false;
+ }
+ return true;
+}
+
+Pass::Status DeadBranchElimPass::ProcessImpl() {
+ // Current functionality assumes structured control flow.
+ // TODO(greg-lunarg): Handle non-structured control-flow.
+ if (!module_->HasCapability(SpvCapabilityShader))
+ return Status::SuccessWithoutChange;
+ // Do not process if module contains OpGroupDecorate. Additional
+ // support required in KillNamesAndDecorates().
+ // TODO(greg-lunarg): Add support for OpGroupDecorate
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpGroupDecorate)
+ return Status::SuccessWithoutChange;
+ // Do not process if any disallowed extensions are enabled
+ if (!AllExtensionsSupported())
+ return Status::SuccessWithoutChange;
+ // Collect all named and decorated ids
+ FindNamedOrDecoratedIds();
+ // Process all entry point functions
+ bool modified = false;
+ for (const auto& e : module_->entry_points()) {
+ ir::Function* fn =
+ id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
+ modified = EliminateDeadBranches(fn) || modified;
+ }
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
// Construct with no module attached; Process() binds one via Initialize().
DeadBranchElimPass::DeadBranchElimPass()
    : module_(nullptr), def_use_mgr_(nullptr) {}
+
+Pass::Status DeadBranchElimPass::Process(ir::Module* module) {
+ Initialize(module);
+ return ProcessImpl();
+}
+
+void DeadBranchElimPass::InitExtensions() {
+ extensions_whitelist_.clear();
+ extensions_whitelist_.insert({
+ "SPV_AMD_shader_explicit_vertex_parameter",
+ "SPV_AMD_shader_trinary_minmax",
+ "SPV_AMD_gcn_shader",
+ "SPV_KHR_shader_ballot",
+ "SPV_AMD_shader_ballot",
+ "SPV_AMD_gpu_shader_half_float",
+ "SPV_KHR_shader_draw_parameters",
+ "SPV_KHR_subgroup_vote",
+ "SPV_KHR_16bit_storage",
+ "SPV_KHR_device_group",
+ "SPV_KHR_multiview",
+ "SPV_NVX_multiview_per_view_attributes",
+ "SPV_NV_viewport_array2",
+ "SPV_NV_stereo_view_rendering",
+ "SPV_NV_sample_mask_override_coverage",
+ "SPV_NV_geometry_shader_passthrough",
+ "SPV_AMD_texture_gather_bias_lod",
+ "SPV_KHR_storage_buffer_storage_class",
+ "SPV_KHR_variable_pointers",
+ "SPV_AMD_gpu_shader_int16",
+ "SPV_KHR_post_depth_coverage",
+ "SPV_KHR_shader_atomic_counter_ops",
+ });
+}
+
+} // namespace opt
+} // namespace spvtools
+