diff options
Diffstat (limited to 'source/lint/divergence_analysis.cpp')
-rw-r--r-- | source/lint/divergence_analysis.cpp | 245 |
1 files changed, 245 insertions, 0 deletions
diff --git a/source/lint/divergence_analysis.cpp b/source/lint/divergence_analysis.cpp new file mode 100644 index 00000000..b5a72b45 --- /dev/null +++ b/source/lint/divergence_analysis.cpp @@ -0,0 +1,245 @@ +// Copyright (c) 2021 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/lint/divergence_analysis.h" + +#include "source/opt/basic_block.h" +#include "source/opt/control_dependence.h" +#include "source/opt/dataflow.h" +#include "source/opt/function.h" +#include "source/opt/instruction.h" +#include "spirv/unified1/spirv.h" + +namespace spvtools { +namespace lint { + +void DivergenceAnalysis::EnqueueSuccessors(opt::Instruction* inst) { + // Enqueue control dependents of block, if applicable. + // There are two ways for a dependence source to be updated: + // 1. control -> control: source block is marked divergent. + // 2. data -> control: branch condition is marked divergent. + uint32_t block_id; + if (inst->IsBlockTerminator()) { + block_id = context().get_instr_block(inst)->id(); + } else if (inst->opcode() == SpvOpLabel) { + block_id = inst->result_id(); + opt::BasicBlock* bb = context().cfg()->block(block_id); + // Only enqueue phi instructions, as other uses don't affect divergence. + bb->ForEachPhiInst([this](opt::Instruction* phi) { Enqueue(phi); }); + } else { + opt::ForwardDataFlowAnalysis::EnqueueUsers(inst); + return; + } + if (!cd_.HasBlock(block_id)) { + return; + } + for (const spvtools::opt::ControlDependence& dep : + cd_.GetDependenceTargets(block_id)) { + opt::Instruction* target_inst = + context().cfg()->block(dep.target_bb_id())->GetLabelInst(); + Enqueue(target_inst); + } +} + +opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::Visit( + opt::Instruction* inst) { + if (inst->opcode() == SpvOpLabel) { + return VisitBlock(inst->result_id()); + } else { + return VisitInstruction(inst); + } +} + +opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitBlock(uint32_t id) { + if (!cd_.HasBlock(id)) { + return opt::DataFlowAnalysis::VisitResult::kResultFixed; + } + DivergenceLevel& cur_level = divergence_[id]; + if (cur_level == DivergenceLevel::kDivergent) { + return opt::DataFlowAnalysis::VisitResult::kResultFixed; + } + DivergenceLevel orig = cur_level; + for (const spvtools::opt::ControlDependence& dep : + cd_.GetDependenceSources(id)) { + if (divergence_[dep.source_bb_id()] > cur_level) { + cur_level = divergence_[dep.source_bb_id()]; + divergence_source_[id] = dep.source_bb_id(); + } else if (dep.source_bb_id() != 0) { + uint32_t condition_id = dep.GetConditionID(*context().cfg()); + DivergenceLevel dep_level = divergence_[condition_id]; + // Check if we are along the chain of unconditional branches starting from + // the branch target. + if (follow_unconditional_branches_[dep.branch_target_bb_id()] != + follow_unconditional_branches_[dep.target_bb_id()]) { + // We must have reconverged in order to reach this block. + // Promote partially uniform to divergent. + if (dep_level == DivergenceLevel::kPartiallyUniform) { + dep_level = DivergenceLevel::kDivergent; + } + } + if (dep_level > cur_level) { + cur_level = dep_level; + divergence_source_[id] = condition_id; + divergence_dependence_source_[id] = dep.source_bb_id(); + } + } + } + return cur_level > orig ? VisitResult::kResultChanged + : VisitResult::kResultFixed; +} + +opt::DataFlowAnalysis::VisitResult DivergenceAnalysis::VisitInstruction( + opt::Instruction* inst) { + if (inst->IsBlockTerminator()) { + // This is called only when the condition has changed, so return changed. + return VisitResult::kResultChanged; + } + if (!inst->HasResultId()) { + return VisitResult::kResultFixed; + } + uint32_t id = inst->result_id(); + DivergenceLevel& cur_level = divergence_[id]; + if (cur_level == DivergenceLevel::kDivergent) { + return opt::DataFlowAnalysis::VisitResult::kResultFixed; + } + DivergenceLevel orig = cur_level; + cur_level = ComputeInstructionDivergence(inst); + return cur_level > orig ? VisitResult::kResultChanged + : VisitResult::kResultFixed; +} + +DivergenceAnalysis::DivergenceLevel +DivergenceAnalysis::ComputeInstructionDivergence(opt::Instruction* inst) { + // TODO(kuhar): Check to see if inst is decorated with Uniform or UniformId + // and use that to short circuit other checks. Uniform is for subgroups which + // would satisfy derivative groups too. UniformId takes a scope, so if it is + // subgroup or greater it could satisfy derivative group and + // Device/QueueFamily could satisfy fully uniform. + uint32_t id = inst->result_id(); + // Handle divergence roots. + if (inst->opcode() == SpvOpFunctionParameter) { + divergence_source_[id] = 0; + return divergence_[id] = DivergenceLevel::kDivergent; + } else if (inst->IsLoad()) { + spvtools::opt::Instruction* var = inst->GetBaseAddress(); + if (var->opcode() != SpvOpVariable) { + // Assume divergent. + divergence_source_[id] = 0; + return DivergenceLevel::kDivergent; + } + DivergenceLevel ret = ComputeVariableDivergence(var); + if (ret > DivergenceLevel::kUniform) { + divergence_source_[inst->result_id()] = 0; + } + return divergence_[id] = ret; + } + // Get the maximum divergence of the operands. + DivergenceLevel ret = DivergenceLevel::kUniform; + inst->ForEachInId([this, inst, &ret](const uint32_t* op) { + if (!op) return; + if (divergence_[*op] > ret) { + divergence_source_[inst->result_id()] = *op; + ret = divergence_[*op]; + } + }); + divergence_[inst->result_id()] = ret; + return ret; +} + +DivergenceAnalysis::DivergenceLevel +DivergenceAnalysis::ComputeVariableDivergence(opt::Instruction* var) { + uint32_t type_id = var->type_id(); + spvtools::opt::analysis::Pointer* type = + context().get_type_mgr()->GetType(type_id)->AsPointer(); + assert(type != nullptr); + uint32_t def_id = var->result_id(); + DivergenceLevel ret; + switch (type->storage_class()) { + case SpvStorageClassFunction: + case SpvStorageClassGeneric: + case SpvStorageClassAtomicCounter: + case SpvStorageClassStorageBuffer: + case SpvStorageClassPhysicalStorageBuffer: + case SpvStorageClassOutput: + case SpvStorageClassWorkgroup: + case SpvStorageClassImage: // Image atomics probably aren't uniform. + case SpvStorageClassPrivate: + ret = DivergenceLevel::kDivergent; + break; + case SpvStorageClassInput: + ret = DivergenceLevel::kDivergent; + // If this variable has a Flat decoration, it is partially uniform. + // TODO(kuhar): Track access chain indices and also consider Flat members + // of a structure. + context().get_decoration_mgr()->WhileEachDecoration( + def_id, SpvDecorationFlat, [&ret](const opt::Instruction&) { + ret = DivergenceLevel::kPartiallyUniform; + return false; + }); + break; + case SpvStorageClassUniformConstant: + // May be a storage image which is also written to; mark those as + // divergent. + if (!var->IsVulkanStorageImage() || var->IsReadOnlyPointer()) { + ret = DivergenceLevel::kUniform; + } else { + ret = DivergenceLevel::kDivergent; + } + break; + case SpvStorageClassUniform: + case SpvStorageClassPushConstant: + case SpvStorageClassCrossWorkgroup: // Not for shaders; default uniform. + default: + ret = DivergenceLevel::kUniform; + break; + } + return ret; +} + +void DivergenceAnalysis::Setup(opt::Function* function) { + // TODO(kuhar): Run functions called by |function| so we can detect + // reconvergence caused by multiple returns. + cd_.ComputeControlDependenceGraph( + *context().cfg(), *context().GetPostDominatorAnalysis(function)); + context().cfg()->ForEachBlockInPostOrder( + function->entry().get(), [this](const opt::BasicBlock* bb) { + uint32_t id = bb->id(); + if (bb->terminator() == nullptr || + bb->terminator()->opcode() != SpvOpBranch) { + follow_unconditional_branches_[id] = id; + } else { + uint32_t target_id = bb->terminator()->GetSingleWordInOperand(0); + // Target is guaranteed to have been visited before us in postorder. + follow_unconditional_branches_[id] = + follow_unconditional_branches_[target_id]; + } + }); +} + +std::ostream& operator<<(std::ostream& os, + DivergenceAnalysis::DivergenceLevel level) { + switch (level) { + case DivergenceAnalysis::DivergenceLevel::kUniform: + return os << "uniform"; + case DivergenceAnalysis::DivergenceLevel::kPartiallyUniform: + return os << "partially uniform"; + case DivergenceAnalysis::DivergenceLevel::kDivergent: + return os << "divergent"; + default: + return os << "<invalid divergence level>"; + } +} + +} // namespace lint +} // namespace spvtools |