diff options
Diffstat (limited to 'source/opt')
96 files changed, 1535 insertions, 632 deletions
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 7d522fb5..61e7a981 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -45,6 +45,7 @@ set(SPIRV_TOOLS_OPT_SOURCES eliminate_dead_constant_pass.h eliminate_dead_functions_pass.h eliminate_dead_functions_util.h + eliminate_dead_input_components_pass.h eliminate_dead_members_pass.h empty_pass.h feature_manager.h @@ -99,6 +100,7 @@ set(SPIRV_TOOLS_OPT_SOURCES reflect.h register_pressure.h relax_float_ops_pass.h + remove_dontinline_pass.h remove_duplicates_pass.h remove_unused_interface_variables_pass.h replace_desc_array_access_using_var_index.h @@ -108,10 +110,11 @@ set(SPIRV_TOOLS_OPT_SOURCES scalar_replacement_pass.h set_spec_constant_default_value_pass.h simplification_pass.h + spread_volatile_semantics.h ssa_rewrite_pass.h strength_reduction_pass.h strip_debug_info_pass.h - strip_reflect_info_pass.h + strip_nonsemantic_info_pass.h struct_cfg_analysis.h tree_iterator.h type_manager.h @@ -156,6 +159,7 @@ set(SPIRV_TOOLS_OPT_SOURCES eliminate_dead_constant_pass.cpp eliminate_dead_functions_pass.cpp eliminate_dead_functions_util.cpp + eliminate_dead_input_components_pass.cpp eliminate_dead_members_pass.cpp feature_manager.cpp fix_storage_class.cpp @@ -206,6 +210,7 @@ set(SPIRV_TOOLS_OPT_SOURCES redundancy_elimination.cpp register_pressure.cpp relax_float_ops_pass.cpp + remove_dontinline_pass.cpp remove_duplicates_pass.cpp remove_unused_interface_variables_pass.cpp replace_desc_array_access_using_var_index.cpp @@ -215,10 +220,11 @@ set(SPIRV_TOOLS_OPT_SOURCES scalar_replacement_pass.cpp set_spec_constant_default_value_pass.cpp simplification_pass.cpp + spread_volatile_semantics.cpp ssa_rewrite_pass.cpp strength_reduction_pass.cpp strip_debug_info_pass.cpp - strip_reflect_info_pass.cpp + strip_nonsemantic_info_pass.cpp struct_cfg_analysis.cpp type_manager.cpp types.cpp diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index 0b54d5e8..04737521 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -27,6 +27,7 @@ #include "source/opt/iterator.h" #include "source/opt/reflect.h" #include "source/spirv_constant.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -146,8 +147,7 @@ void AggressiveDCEPass::AddStores(Function* func, uint32_t ptrId) { bool AggressiveDCEPass::AllExtensionsSupported() const { // If any extension not in allowlist, return false for (auto& ei : get_module()->extensions()) { - const char* extName = - reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]); + const std::string extName = ei.GetInOperand(0).AsString(); if (extensions_allowlist_.find(extName) == extensions_allowlist_.end()) return false; } @@ -156,11 +156,9 @@ bool AggressiveDCEPass::AllExtensionsSupported() const { for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12) && - 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100", - 32)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.") && + extension_name != "NonSemantic.Shader.DebugInfo.100") { return false; } } @@ -569,12 +567,7 @@ void AggressiveDCEPass::InitializeModuleScopeLiveInstructions() { } // Keep all entry points. for (auto& entry : get_module()->entry_points()) { - if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4) && - !preserve_interface_) { - // In SPIR-V 1.4 and later, entry points must list all global variables - // used. DCE can still remove non-input/output variables and update the - // interface list. Mark the entry point as live and inputs and outputs as - // live, but defer decisions all other interfaces. + if (!preserve_interface_) { live_insts_.Set(entry.unique_id()); // The actual function is live always. AddToWorklist( @@ -582,8 +575,9 @@ void AggressiveDCEPass::InitializeModuleScopeLiveInstructions() { for (uint32_t i = 3; i < entry.NumInOperands(); ++i) { auto* var = get_def_use_mgr()->GetDef(entry.GetSingleWordInOperand(i)); auto storage_class = var->GetSingleWordInOperand(0u); - if (storage_class == SpvStorageClassInput || - storage_class == SpvStorageClassOutput) { + // Vulkan support outputs without an associated input, but not inputs + // without an associated output. + if (storage_class == SpvStorageClassOutput) { AddToWorklist(var); } } @@ -885,8 +879,7 @@ bool AggressiveDCEPass::ProcessGlobalValues() { } } - if (get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 4) && - !preserve_interface_) { + if (!preserve_interface_) { // Remove the dead interface variables from the entry point interface list. for (auto& entry : get_module()->entry_points()) { std::vector<Operand> new_operands; @@ -974,6 +967,7 @@ void AggressiveDCEPass::InitExtensions() { "SPV_KHR_integer_dot_product", "SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info", + "SPV_KHR_uniform_group_instructions", }); } diff --git a/source/opt/amd_ext_to_khr.cpp b/source/opt/amd_ext_to_khr.cpp index ccedc0bc..dd9bafda 100644 --- a/source/opt/amd_ext_to_khr.cpp +++ b/source/opt/amd_ext_to_khr.cpp @@ -584,9 +584,9 @@ bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst, } // Get the constants that will be used. - uint32_t f0_const_id = const_mgr->GetFloatConst(0.0); - uint32_t f2_const_id = const_mgr->GetFloatConst(2.0); - uint32_t f0_5_const_id = const_mgr->GetFloatConst(0.5); + uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0); + uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0); + uint32_t f0_5_const_id = const_mgr->GetFloatConstId(0.5); const analysis::Constant* vec_const = const_mgr->GetConstant(v2_float_type, {f0_5_const_id, f0_5_const_id}); uint32_t vec_const_id = @@ -731,12 +731,12 @@ bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst, } // Get the constants that will be used. - uint32_t f0_const_id = const_mgr->GetFloatConst(0.0); - uint32_t f1_const_id = const_mgr->GetFloatConst(1.0); - uint32_t f2_const_id = const_mgr->GetFloatConst(2.0); - uint32_t f3_const_id = const_mgr->GetFloatConst(3.0); - uint32_t f4_const_id = const_mgr->GetFloatConst(4.0); - uint32_t f5_const_id = const_mgr->GetFloatConst(5.0); + uint32_t f0_const_id = const_mgr->GetFloatConstId(0.0); + uint32_t f1_const_id = const_mgr->GetFloatConstId(1.0); + uint32_t f2_const_id = const_mgr->GetFloatConstId(2.0); + uint32_t f3_const_id = const_mgr->GetFloatConstId(3.0); + uint32_t f4_const_id = const_mgr->GetFloatConstId(4.0); + uint32_t f5_const_id = const_mgr->GetFloatConstId(5.0); // Extract the input values. Instruction* x = ir_builder.AddCompositeExtract(float_type_id, input_id, {0}); @@ -935,8 +935,7 @@ Pass::Status AmdExtensionToKhrPass::Process() { std::vector<Instruction*> to_be_killed; for (Instruction& inst : context()->module()->extensions()) { if (inst.opcode() == SpvOpExtension) { - if (ext_to_remove.count(reinterpret_cast<const char*>( - &(inst.GetInOperand(0).words[0]))) != 0) { + if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) { to_be_killed.push_back(&inst); } } @@ -944,8 +943,7 @@ Pass::Status AmdExtensionToKhrPass::Process() { for (Instruction& inst : context()->ext_inst_imports()) { if (inst.opcode() == SpvOpExtInstImport) { - if (ext_to_remove.count(reinterpret_cast<const char*>( - &(inst.GetInOperand(0).words[0]))) != 0) { + if (ext_to_remove.count(inst.GetInOperand(0).AsString()) != 0) { to_be_killed.push_back(&inst); } } diff --git a/source/opt/amd_ext_to_khr.h b/source/opt/amd_ext_to_khr.h index fd3dab4e..6a39d953 100644 --- a/source/opt/amd_ext_to_khr.h +++ b/source/opt/amd_ext_to_khr.h @@ -23,7 +23,7 @@ namespace spvtools { namespace opt { // Replaces the extensions VK_AMD_shader_ballot, VK_AMD_gcn_shader, and -// VK_AMD_shader_trinary_minmax with equivalant code using core instructions and +// VK_AMD_shader_trinary_minmax with equivalent code using core instructions and // capabilities. class AmdExtensionToKhrPass : public Pass { public: diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h index 6741a50f..dd3b2e28 100644 --- a/source/opt/basic_block.h +++ b/source/opt/basic_block.h @@ -83,7 +83,7 @@ class BasicBlock { const Instruction* GetMergeInst() const; Instruction* GetMergeInst(); - // Returns the OpLoopMerge instruciton in this basic block, if it exists. + // Returns the OpLoopMerge instruction in this basic block, if it exists. // Otherwise return null. May be used whenever tail() can be used. const Instruction* GetLoopMergeInst() const; Instruction* GetLoopMergeInst(); diff --git a/source/opt/ccp_pass.cpp b/source/opt/ccp_pass.cpp index 8b896d50..5f855027 100644 --- a/source/opt/ccp_pass.cpp +++ b/source/opt/ccp_pass.cpp @@ -102,6 +102,34 @@ SSAPropagator::PropStatus CCPPass::VisitPhi(Instruction* phi) { return SSAPropagator::kInteresting; } +uint32_t CCPPass::ComputeLatticeMeet(Instruction* instr, uint32_t val2) { + // Given two values val1 and val2, the meet operation in the constant + // lattice uses the following rules: + // + // meet(val1, UNDEFINED) = val1 + // meet(val1, VARYING) = VARYING + // meet(val1, val2) = val1 if val1 == val2 + // meet(val1, val2) = VARYING if val1 != val2 + // + // When two different values meet, the result is always varying because CCP + // does not allow lateral transitions in the lattice. This prevents + // infinite cycles during propagation. + auto val1_it = values_.find(instr->result_id()); + if (val1_it == values_.end()) { + return val2; + } + + uint32_t val1 = val1_it->second; + if (IsVaryingValue(val1)) { + return val1; + } else if (IsVaryingValue(val2)) { + return val2; + } else if (val1 != val2) { + return kVaryingSSAId; + } + return val2; +} + SSAPropagator::PropStatus CCPPass::VisitAssignment(Instruction* instr) { assert(instr->result_id() != 0 && "Expecting an instruction that produces a result"); @@ -115,8 +143,10 @@ SSAPropagator::PropStatus CCPPass::VisitAssignment(Instruction* instr) { if (IsVaryingValue(it->second)) { return MarkInstructionVarying(instr); } else { - values_[instr->result_id()] = it->second; - return SSAPropagator::kInteresting; + uint32_t new_val = ComputeLatticeMeet(instr, it->second); + values_[instr->result_id()] = new_val; + return IsVaryingValue(new_val) ? SSAPropagator::kVarying + : SSAPropagator::kInteresting; } } return SSAPropagator::kNotInteresting; @@ -142,9 +172,13 @@ SSAPropagator::PropStatus CCPPass::VisitAssignment(Instruction* instr) { if (folded_inst != nullptr) { // We do not want to change the body of the function by adding new // instructions. When folding we can only generate new constants. - assert(folded_inst->IsConstant() && "CCP is only interested in constant."); - values_[instr->result_id()] = folded_inst->result_id(); - return SSAPropagator::kInteresting; + assert((folded_inst->IsConstant() || + IsSpecConstantInst(folded_inst->opcode())) && + "CCP is only interested in constant values."); + uint32_t new_val = ComputeLatticeMeet(instr, folded_inst->result_id()); + values_[instr->result_id()] = new_val; + return IsVaryingValue(new_val) ? SSAPropagator::kVarying + : SSAPropagator::kInteresting; } // Conservatively mark this instruction as varying if any input id is varying. diff --git a/source/opt/ccp_pass.h b/source/opt/ccp_pass.h index fb20c780..77ea9f80 100644 --- a/source/opt/ccp_pass.h +++ b/source/opt/ccp_pass.h @@ -92,6 +92,22 @@ class CCPPass : public MemPass { // generated during propagation. analysis::ConstantManager* const_mgr_; + // Returns a new value for |instr| by computing the meet operation between + // its existing value and |val2|. + // + // Given two values val1 and val2, the meet operation in the constant + // lattice uses the following rules: + // + // meet(val1, UNDEFINED) = val1 + // meet(val1, VARYING) = VARYING + // meet(val1, val2) = val1 if val1 == val2 + // meet(val1, val2) = VARYING if val1 != val2 + // + // When two different values meet, the result is always varying because CCP + // does not allow lateral transitions in the lattice. This prevents + // infinite cycles during propagation. + uint32_t ComputeLatticeMeet(Instruction* instr, uint32_t val2); + // Constant value table. Each entry <id, const_decl_id> in this map // represents the compile-time constant value for |id| as declared by // |const_decl_id|. Each |const_decl_id| in this table is an OpConstant diff --git a/source/opt/cfg.h b/source/opt/cfg.h index f2806822..33412f18 100644 --- a/source/opt/cfg.h +++ b/source/opt/cfg.h @@ -30,7 +30,7 @@ class CFG { public: explicit CFG(Module* module); - // Return the list of predecesors for basic block with label |blkid|. + // Return the list of predecessors for basic block with label |blkid|. // TODO(dnovillo): Move this to BasicBlock. const std::vector<uint32_t>& preds(uint32_t blk_id) const { assert(label2preds_.count(blk_id)); diff --git a/source/opt/compact_ids_pass.cpp b/source/opt/compact_ids_pass.cpp index 8815b8c6..70848d79 100644 --- a/source/opt/compact_ids_pass.cpp +++ b/source/opt/compact_ids_pass.cpp @@ -86,7 +86,8 @@ Pass::Status CompactIdsPass::Process() { }, true); - if (modified) { + if (context()->module()->id_bound() != result_id_mapping.size() + 1) { + modified = true; context()->module()->SetIdBound( static_cast<uint32_t>(result_id_mapping.size() + 1)); // There are ids in the feature manager that could now be invalid diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index 515a3ed5..249e11e5 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -22,6 +22,45 @@ namespace { const uint32_t kExtractCompositeIdInIdx = 0; +// Returns a constants with the value NaN of the given type. Only works for +// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. +const analysis::Constant* GetNan(const analysis::Type* type, + analysis::ConstantManager* const_mgr) { + const analysis::Float* float_type = type->AsFloat(); + if (float_type == nullptr) { + return nullptr; + } + + switch (float_type->width()) { + case 32: + return const_mgr->GetFloatConst(std::numeric_limits<float>::quiet_NaN()); + case 64: + return const_mgr->GetDoubleConst( + std::numeric_limits<double>::quiet_NaN()); + default: + return nullptr; + } +} + +// Returns a constants with the value INF of the given type. Only works for +// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. +const analysis::Constant* GetInf(const analysis::Type* type, + analysis::ConstantManager* const_mgr) { + const analysis::Float* float_type = type->AsFloat(); + if (float_type == nullptr) { + return nullptr; + } + + switch (float_type->width()) { + case 32: + return const_mgr->GetFloatConst(std::numeric_limits<float>::infinity()); + case 64: + return const_mgr->GetDoubleConst(std::numeric_limits<double>::infinity()); + default: + return nullptr; + } +} + // Returns true if |type| is Float or a vector of Float. bool HasFloatingPoint(const analysis::Type* type) { if (type->AsFloat()) { @@ -33,6 +72,23 @@ bool HasFloatingPoint(const analysis::Type* type) { return false; } +// Returns a constants with the value |-val| of the given type. Only works for +// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs. +const analysis::Constant* negateFPConst(const analysis::Type* result_type, + const analysis::Constant* val, + analysis::ConstantManager* const_mgr) { + const analysis::Float* float_type = result_type->AsFloat(); + assert(float_type != nullptr); + if (float_type->width() == 32) { + float fa = val->GetFloat(); + return const_mgr->GetFloatConst(-fa); + } else if (float_type->width() == 64) { + double da = val->GetDouble(); + return const_mgr->GetDoubleConst(-da); + } + return nullptr; +} + // Folds an OpcompositeExtract where input is a composite constant. ConstantFoldingRule FoldExtractWithConstants() { return [](IRContext* context, Instruction* inst, @@ -492,7 +548,60 @@ ConstantFoldingRule FoldQuantizeToF16() { ConstantFoldingRule FoldFSub() { return FoldFPBinaryOp(FOLD_FPARITH_OP(-)); } ConstantFoldingRule FoldFAdd() { return FoldFPBinaryOp(FOLD_FPARITH_OP(+)); } ConstantFoldingRule FoldFMul() { return FoldFPBinaryOp(FOLD_FPARITH_OP(*)); } -ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FOLD_FPARITH_OP(/)); } + +// Returns the constant that results from evaluating |numerator| / 0.0. Returns +// |nullptr| if the result could not be evaluated. +const analysis::Constant* FoldFPScalarDivideByZero( + const analysis::Type* result_type, const analysis::Constant* numerator, + analysis::ConstantManager* const_mgr) { + if (numerator == nullptr) { + return nullptr; + } + + if (numerator->IsZero()) { + return GetNan(result_type, const_mgr); + } + + const analysis::Constant* result = GetInf(result_type, const_mgr); + if (result == nullptr) { + return nullptr; + } + + if (numerator->AsFloatConstant()->GetValueAsDouble() < 0.0) { + result = negateFPConst(result_type, result, const_mgr); + } + return result; +} + +// Returns the result of folding |numerator| / |denominator|. Returns |nullptr| +// if it cannot be folded. +const analysis::Constant* FoldScalarFPDivide( + const analysis::Type* result_type, const analysis::Constant* numerator, + const analysis::Constant* denominator, + analysis::ConstantManager* const_mgr) { + if (denominator == nullptr) { + return nullptr; + } + + if (denominator->IsZero()) { + return FoldFPScalarDivideByZero(result_type, numerator, const_mgr); + } + + const analysis::FloatConstant* denominator_float = + denominator->AsFloatConstant(); + if (denominator_float && denominator->GetValueAsDouble() == -0.0) { + const analysis::Constant* result = + FoldFPScalarDivideByZero(result_type, numerator, const_mgr); + if (result != nullptr) + result = negateFPConst(result_type, result, const_mgr); + return result; + } else { + return FOLD_FPARITH_OP(/)(result_type, numerator, denominator, const_mgr); + } +} + +// Returns the constant folding rule to fold |OpFDiv| with two constants. +ConstantFoldingRule FoldFDiv() { return FoldFPBinaryOp(FoldScalarFPDivide); } bool CompareFloatingPoint(bool op_result, bool op_unordered, bool need_ordered) { @@ -655,20 +764,7 @@ UnaryScalarFoldingRule FoldFNegateOp() { analysis::ConstantManager* const_mgr) -> const analysis::Constant* { assert(result_type != nullptr && a != nullptr); assert(result_type == a->type()); - const analysis::Float* float_type = result_type->AsFloat(); - assert(float_type != nullptr); - if (float_type->width() == 32) { - float fa = a->GetFloat(); - utils::FloatProxy<float> result(-fa); - std::vector<uint32_t> words = result.GetWords(); - return const_mgr->GetConstant(result_type, words); - } else if (float_type->width() == 64) { - double da = a->GetDouble(); - utils::FloatProxy<double> result(-da); - std::vector<uint32_t> words = result.GetWords(); - return const_mgr->GetConstant(result_type, words); - } - return nullptr; + return negateFPConst(result_type, a, const_mgr); }; } @@ -1250,7 +1346,7 @@ void ConstantFoldingRules::AddFoldingRules() { FoldFPUnaryOp(FoldFTranscendentalUnary(std::log))); #ifdef __ANDROID__ - // Android NDK r15c tageting ABI 15 doesn't have full support for C++11 + // Android NDK r15c targeting ABI 15 doesn't have full support for C++11 // (no std::exp2/log2). ::exp2 is available from C99 but ::log2 isn't // available up until ABI 18 so we use a shim auto log2_shim = [](double v) -> double { return log(v) / log(2.0); }; diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp index 020e248b..bcff08c1 100644 --- a/source/opt/constants.cpp +++ b/source/opt/constants.cpp @@ -158,6 +158,7 @@ Type* ConstantManager::GetType(const Instruction* inst) const { std::vector<const Constant*> ConstantManager::GetOperandConstants( const Instruction* inst) const { std::vector<const Constant*> constants; + constants.reserve(inst->NumInOperands()); for (uint32_t i = 0; i < inst->NumInOperands(); i++) { const Operand* operand = &inst->GetInOperand(i); if (operand->type != SPV_OPERAND_TYPE_ID) { @@ -420,13 +421,30 @@ const Constant* ConstantManager::GetNumericVectorConstantWithWords( return GetConstant(type, element_ids); } -uint32_t ConstantManager::GetFloatConst(float val) { +uint32_t ConstantManager::GetFloatConstId(float val) { + const Constant* c = GetFloatConst(val); + return GetDefiningInstruction(c)->result_id(); +} + +const Constant* ConstantManager::GetFloatConst(float val) { Type* float_type = context()->get_type_mgr()->GetFloatType(); utils::FloatProxy<float> v(val); const Constant* c = GetConstant(float_type, v.GetWords()); + return c; +} + +uint32_t ConstantManager::GetDoubleConstId(double val) { + const Constant* c = GetDoubleConst(val); return GetDefiningInstruction(c)->result_id(); } +const Constant* ConstantManager::GetDoubleConst(double val) { + Type* float_type = context()->get_type_mgr()->GetDoubleType(); + utils::FloatProxy<double> v(val); + const Constant* c = GetConstant(float_type, v.GetWords()); + return c; +} + uint32_t ConstantManager::GetSIntConst(int32_t val) { Type* sint_type = context()->get_type_mgr()->GetSIntType(); const Constant* c = GetConstant(sint_type, {static_cast<uint32_t>(val)}); diff --git a/source/opt/constants.h b/source/opt/constants.h index 52bd809a..c039ae08 100644 --- a/source/opt/constants.h +++ b/source/opt/constants.h @@ -541,7 +541,7 @@ class ConstantManager { // instruction at the end of the current module's types section. // // |type_id| is an optional argument for disambiguating equivalent types. If - // |type_id| is specified, the contant returned will have that type id. + // |type_id| is specified, the constant returned will have that type id. Instruction* GetDefiningInstruction(const Constant* c, uint32_t type_id = 0, Module::inst_iterator* pos = nullptr); @@ -637,7 +637,16 @@ class ConstantManager { } // Returns the id of a 32-bit floating point constant with value |val|. - uint32_t GetFloatConst(float val); + uint32_t GetFloatConstId(float val); + + // Returns a 32-bit float constant with the given value. + const Constant* GetFloatConst(float val); + + // Returns the id of a 64-bit floating point constant with value |val|. + uint32_t GetDoubleConstId(double val); + + // Returns a 64-bit float constant with the given value. + const Constant* GetDoubleConst(double val); // Returns the id of a 32-bit signed integer constant with value |val|. uint32_t GetSIntConst(int32_t val); diff --git a/source/opt/convert_to_half_pass.cpp b/source/opt/convert_to_half_pass.cpp index b127eabe..4086e31a 100644 --- a/source/opt/convert_to_half_pass.cpp +++ b/source/opt/convert_to_half_pass.cpp @@ -181,7 +181,7 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width, uint32_t to_width) { // Add converts of any float operands to to_width if they are of from_width. // If converting to 16, change type of phi to float16 equivalent and remember - // result id. Converts need to be added to preceeding blocks. + // result id. Converts need to be added to preceding blocks. uint32_t ocnt = 0; uint32_t* prev_idp; bool modified = false; diff --git a/source/opt/copy_prop_arrays.cpp b/source/opt/copy_prop_arrays.cpp index 62ed5e77..321d4969 100644 --- a/source/opt/copy_prop_arrays.cpp +++ b/source/opt/copy_prop_arrays.cpp @@ -745,11 +745,11 @@ void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst, context()->AnalyzeUses(use); } break; + case SpvOpDecorate: + // We treat an OpImageTexelPointer as a load. The result type should + // always have the Image storage class, and should not need to be + // updated. case SpvOpImageTexelPointer: - // We treat an OpImageTexelPointer as a load. The result type should - // always have the Image storage class, and should not need to be - // updated. - // Replace the actual use. context()->ForgetUses(use); use->SetOperand(index, {new_ptr_inst->result_id()}); diff --git a/source/opt/copy_prop_arrays.h b/source/opt/copy_prop_arrays.h index f4314a74..46a508cf 100644 --- a/source/opt/copy_prop_arrays.h +++ b/source/opt/copy_prop_arrays.h @@ -35,7 +35,7 @@ namespace opt { // // The hard part is keeping all of the types correct. We do not want to // have to do too large a search to update everything, which may not be -// possible, do we give up if we see any instruction that might be hard to +// possible, so we give up if we see any instruction that might be hard to // update. class CopyPropagateArrays : public MemPass { diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp index 356dbcb3..cc616ca6 100644 --- a/source/opt/dead_branch_elim_pass.cpp +++ b/source/opt/dead_branch_elim_pass.cpp @@ -207,7 +207,7 @@ bool DeadBranchElimPass::SimplifyBranch(BasicBlock* block, Instruction::OperandList new_operands; new_operands.push_back(terminator->GetInOperand(0)); new_operands.push_back({SPV_OPERAND_TYPE_ID, {live_lab_id}}); - terminator->SetInOperands(move(new_operands)); + terminator->SetInOperands(std::move(new_operands)); context()->UpdateDefUse(terminator); } else { // Check if the merge instruction is still needed because of a diff --git a/source/opt/dead_branch_elim_pass.h b/source/opt/dead_branch_elim_pass.h index 7841bc47..198bad2d 100644 --- a/source/opt/dead_branch_elim_pass.h +++ b/source/opt/dead_branch_elim_pass.h @@ -98,7 +98,7 @@ class DeadBranchElimPass : public MemPass { // Fix phis in reachable blocks so that only live (or unremovable) incoming // edges are present. If the block now only has a single live incoming edge, // remove the phi and replace its uses with its data input. If the single - // remaining incoming edge is from the phi itself, the the phi is in an + // remaining incoming edge is from the phi itself, the phi is in an // unreachable single block loop. Either the block is dead and will be // removed, or it's reachable from an unreachable continue target. In the // latter case that continue target block will be collapsed into a block that @@ -158,7 +158,7 @@ class DeadBranchElimPass : public MemPass { uint32_t cont_id, uint32_t header_id, uint32_t merge_id, std::unordered_set<BasicBlock*>* blocks_with_back_edges); - // Returns true if there is a brach to the merge node of the selection + // Returns true if there is a branch to the merge node of the selection // construct |switch_header_id| that is inside a nested selection construct or // in the header of the nested selection construct. bool SwitchHasNestedBreak(uint32_t switch_header_id); diff --git a/source/opt/debug_info_manager.cpp b/source/opt/debug_info_manager.cpp index 060e0d93..c1df6258 100644 --- a/source/opt/debug_info_manager.cpp +++ b/source/opt/debug_info_manager.cpp @@ -149,7 +149,7 @@ void DebugInfoManager::RegisterDbgDeclare(uint32_t var_id, // Create new constant directly into global value area, bypassing the // Constant manager. This is used when the DefUse or Constant managers // are invalid and cannot be regenerated due to the module being in an -// inconsistant state e.g. in the middle of significant modification +// inconsistent state e.g. in the middle of significant modification // such as inlining. Invalidate Constant and DefUse managers if used. uint32_t AddNewConstInGlobals(IRContext* context, uint32_t const_value) { uint32_t id = context->TakeNextId(); diff --git a/source/opt/def_use_manager.cpp b/source/opt/def_use_manager.cpp index 394b9fa1..e1e441e0 100644 --- a/source/opt/def_use_manager.cpp +++ b/source/opt/def_use_manager.cpp @@ -13,16 +13,23 @@ // limitations under the License. #include "source/opt/def_use_manager.h" - -#include <iostream> - -#include "source/opt/log.h" -#include "source/opt/reflect.h" +#include "source/util/make_unique.h" namespace spvtools { namespace opt { namespace analysis { +// Don't compact before we have a reasonable number of ids allocated (~32kb). +static const size_t kCompactThresholdMinTotalIds = (8 * 1024); +// Compact when fewer than this fraction of the storage is used (should be 2^n +// for performance). +static const size_t kCompactThresholdFractionFreeIds = 8; + +DefUseManager::DefUseManager() { + use_pool_ = MakeUnique<UseListPool>(); + used_id_pool_ = MakeUnique<UsedIdListPool>(); +} + void DefUseManager::AnalyzeInstDef(Instruction* inst) { const uint32_t def_id = inst->result_id(); if (def_id != 0) { @@ -39,15 +46,15 @@ void DefUseManager::AnalyzeInstDef(Instruction* inst) { } void DefUseManager::AnalyzeInstUse(Instruction* inst) { + // It might have existed before. + EraseUseRecordsOfOperandIds(inst); + // Create entry for the given instruction. Note that the instruction may // not have any in-operands. In such cases, we still need a entry for those // instructions so this manager knows it has seen the instruction later. - auto* used_ids = &inst_to_used_ids_[inst]; - if (used_ids->size()) { - EraseUseRecordsOfOperandIds(inst); - used_ids = &inst_to_used_ids_[inst]; - } - used_ids->clear(); // It might have existed before. + UsedIdList& used_ids = + inst_to_used_id_.insert({inst, UsedIdList(used_id_pool_.get())}) + .first->second; for (uint32_t i = 0; i < inst->NumOperands(); ++i) { switch (inst->GetOperand(i).type) { @@ -58,9 +65,18 @@ void DefUseManager::AnalyzeInstUse(Instruction* inst) { case SPV_OPERAND_TYPE_SCOPE_ID: { uint32_t use_id = inst->GetSingleWordOperand(i); Instruction* def = GetDef(use_id); - if (!def) assert(false && "Definition is not registered."); - id_to_users_.insert(UserEntry(def, inst)); - used_ids->push_back(use_id); + assert(def && "Definition is not registered."); + + // Add to inst's use records + used_ids.push_back(use_id); + + // Add to the users, taking care to avoid adding duplicates. We know + // the duplicate for this instruction will always be at the tail. + UseList& list = inst_to_users_.insert({def, UseList(use_pool_.get())}) + .first->second; + if (list.empty() || list.back() != inst) { + list.push_back(inst); + } } break; default: break; @@ -99,23 +115,6 @@ const Instruction* DefUseManager::GetDef(uint32_t id) const { return iter->second; } -DefUseManager::IdToUsersMap::const_iterator DefUseManager::UsersBegin( - const Instruction* def) const { - return id_to_users_.lower_bound( - UserEntry(const_cast<Instruction*>(def), nullptr)); -} - -bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter, - const IdToUsersMap::const_iterator& cached_end, - const Instruction* inst) const { - return (iter != cached_end && iter->first == inst); -} - -bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter, - const Instruction* inst) const { - return UsersNotEnd(iter, id_to_users_.end(), inst); -} - bool DefUseManager::WhileEachUser( const Instruction* def, const std::function<bool(Instruction*)>& f) const { // Ensure that |def| has been registered. @@ -123,9 +122,11 @@ bool DefUseManager::WhileEachUser( "Definition is not registered."); if (!def->HasResultId()) return true; - auto end = id_to_users_.end(); - for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) { - if (!f(iter->second)) return false; + auto iter = inst_to_users_.find(def); + if (iter != inst_to_users_.end()) { + for (Instruction* user : iter->second) { + if (!f(user)) return false; + } } return true; } @@ -156,14 +157,15 @@ bool DefUseManager::WhileEachUse( "Definition is not registered."); if (!def->HasResultId()) return true; - auto end = id_to_users_.end(); - for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) { - Instruction* user = iter->second; - for (uint32_t idx = 0; idx != user->NumOperands(); ++idx) { - const Operand& op = user->GetOperand(idx); - if (op.type != SPV_OPERAND_TYPE_RESULT_ID && spvIsIdType(op.type)) { - if (def->result_id() == op.words[0]) { - if (!f(user, idx)) return false; + auto iter = inst_to_users_.find(def); + if (iter != inst_to_users_.end()) { + for (Instruction* user : iter->second) { + for (uint32_t idx = 0; idx != user->NumOperands(); ++idx) { + const Operand& op = user->GetOperand(idx); + if (op.type != SPV_OPERAND_TYPE_RESULT_ID && spvIsIdType(op.type)) { + if (def->result_id() == op.words[0]) { + if (!f(user, idx)) return false; + } } } } @@ -235,17 +237,18 @@ void DefUseManager::AnalyzeDefUse(Module* module) { } void DefUseManager::ClearInst(Instruction* inst) { - auto iter = inst_to_used_ids_.find(inst); - if (iter != inst_to_used_ids_.end()) { + if (inst_to_used_id_.find(inst) != inst_to_used_id_.end()) { EraseUseRecordsOfOperandIds(inst); - if (inst->result_id() != 0) { - // Remove all uses of this inst. - auto users_begin = UsersBegin(inst); - auto end = id_to_users_.end(); - auto new_end = users_begin; - for (; UsersNotEnd(new_end, end, inst); ++new_end) { + uint32_t const result_id = inst->result_id(); + if (result_id != 0) { + // For each using instruction, remove result_id from their used ids. + auto iter = inst_to_users_.find(inst); + if (iter != inst_to_users_.end()) { + for (Instruction* use : iter->second) { + inst_to_used_id_.at(use).remove_first(result_id); + } + inst_to_users_.erase(iter); } - id_to_users_.erase(users_begin, new_end); id_to_def_.erase(inst->result_id()); } } @@ -254,59 +257,113 @@ void DefUseManager::ClearInst(Instruction* inst) { void DefUseManager::EraseUseRecordsOfOperandIds(const Instruction* inst) { // Go through all ids used by this instruction, remove this instruction's // uses of them. - auto iter = inst_to_used_ids_.find(inst); - if (iter != inst_to_used_ids_.end()) { - for (auto use_id : iter->second) { - id_to_users_.erase( - UserEntry(GetDef(use_id), const_cast<Instruction*>(inst))); + auto iter = inst_to_used_id_.find(inst); + if (iter != inst_to_used_id_.end()) { + const UsedIdList& used_ids = iter->second; + for (uint32_t def_id : used_ids) { + auto def_iter = inst_to_users_.find(GetDef(def_id)); + if (def_iter != inst_to_users_.end()) { + def_iter->second.remove_first(const_cast<Instruction*>(inst)); + } + } + inst_to_used_id_.erase(inst); + + // If we're using only a fraction of the space in used_ids_, compact storage + // to prevent memory usage from being unbounded. + if (used_id_pool_->total_nodes() > kCompactThresholdMinTotalIds && + used_id_pool_->used_nodes() < + used_id_pool_->total_nodes() / kCompactThresholdFractionFreeIds) { + CompactStorage(); } - inst_to_used_ids_.erase(inst); } } -bool operator==(const DefUseManager& lhs, const DefUseManager& rhs) { +void DefUseManager::CompactStorage() { + CompactUseRecords(); + CompactUsedIds(); +} + +void DefUseManager::CompactUseRecords() { + std::unique_ptr<UseListPool> new_pool = MakeUnique<UseListPool>(); + for (auto& iter : inst_to_users_) { + iter.second.move_nodes(new_pool.get()); + } + use_pool_ = std::move(new_pool); +} + +void DefUseManager::CompactUsedIds() { + std::unique_ptr<UsedIdListPool> new_pool = MakeUnique<UsedIdListPool>(); + for (auto& iter : inst_to_used_id_) { + iter.second.move_nodes(new_pool.get()); + } + used_id_pool_ = std::move(new_pool); +} + +bool CompareAndPrintDifferences(const DefUseManager& lhs, + const DefUseManager& rhs) { + bool same = true; + if (lhs.id_to_def_ != rhs.id_to_def_) { for (auto p : lhs.id_to_def_) { if (rhs.id_to_def_.find(p.first) == rhs.id_to_def_.end()) { - return false; + printf("Diff in id_to_def: missing value in rhs\n"); } } for (auto p : rhs.id_to_def_) { if (lhs.id_to_def_.find(p.first) == lhs.id_to_def_.end()) { - return false; + printf("Diff in id_to_def: missing value in lhs\n"); } } - return false; + same = false; } - if (lhs.id_to_users_ != rhs.id_to_users_) { - for (auto p : lhs.id_to_users_) { - if (rhs.id_to_users_.count(p) == 0) { - return false; - } + for (const auto& l : lhs.inst_to_used_id_) { + std::set<uint32_t> ul, ur; + lhs.ForEachUse(l.first, + [&ul](Instruction*, uint32_t id) { ul.insert(id); }); + rhs.ForEachUse(l.first, + [&ur](Instruction*, uint32_t id) { ur.insert(id); }); + if (ul.size() != ur.size()) { + printf( + "Diff in inst_to_used_id_: different number of used ids (%zu != %zu)", + ul.size(), ur.size()); + same = false; + } else if (ul != ur) { + printf("Diff in inst_to_used_id_: different used ids\n"); + same = false; } - for (auto p : rhs.id_to_users_) { - if (lhs.id_to_users_.count(p) == 0) { - return false; - } + } + for (const auto& r : rhs.inst_to_used_id_) { + auto iter_l = lhs.inst_to_used_id_.find(r.first); + if (r.second.empty() && + !(iter_l == lhs.inst_to_used_id_.end() || iter_l->second.empty())) { + printf("Diff in inst_to_used_id_: unexpected instr in rhs\n"); + same = false; } - return false; } - if (lhs.inst_to_used_ids_ != rhs.inst_to_used_ids_) { - for (auto p : lhs.inst_to_used_ids_) { - if (rhs.inst_to_used_ids_.count(p.first) == 0) { - return false; - } + for (const auto& l : lhs.inst_to_users_) { + std::set<Instruction*> ul, ur; + lhs.ForEachUser(l.first, [&ul](Instruction* use) { ul.insert(use); }); + rhs.ForEachUser(l.first, [&ur](Instruction* use) { ur.insert(use); }); + if (ul.size() != ur.size()) { + printf("Diff in inst_to_users_: different number of users (%zu != %zu)", + ul.size(), ur.size()); + same = false; + } else if (ul != ur) { + printf("Diff in inst_to_users_: different users\n"); + same = false; } - for (auto p : rhs.inst_to_used_ids_) { - if (lhs.inst_to_used_ids_.count(p.first) == 0) { - return false; - } + } + for (const auto& r : rhs.inst_to_users_) { + auto iter_l = lhs.inst_to_users_.find(r.first); + if (r.second.empty() && + !(iter_l == lhs.inst_to_users_.end() || iter_l->second.empty())) { + printf("Diff in inst_to_users_: unexpected instr in rhs\n"); + same = false; } - return false; } - return true; + return same; } } // namespace analysis diff --git a/source/opt/def_use_manager.h b/source/opt/def_use_manager.h index 0499e82b..cf6cbdf5 100644 --- a/source/opt/def_use_manager.h +++ b/source/opt/def_use_manager.h @@ -15,14 +15,13 @@ #ifndef SOURCE_OPT_DEF_USE_MANAGER_H_ #define SOURCE_OPT_DEF_USE_MANAGER_H_ -#include <list> #include <set> #include <unordered_map> -#include <utility> #include <vector> #include "source/opt/instruction.h" #include "source/opt/module.h" +#include "source/util/pooled_linked_list.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { @@ -51,59 +50,16 @@ inline bool operator<(const Use& lhs, const Use& rhs) { return lhs.operand_index < rhs.operand_index; } -// Definition and user pair. -// -// The first element of the pair is the definition. -// The second element of the pair is the user. -// -// Definition should never be null. User can be null, however, such an entry -// should be used only for searching (e.g. all users of a particular definition) -// and never stored in a container. -using UserEntry = std::pair<Instruction*, Instruction*>; - -// Orders UserEntry for use in associative containers (i.e. less than ordering). -// -// The definition of an UserEntry is treated as the major key and the users as -// the minor key so that all the users of a particular definition are -// consecutive in a container. -// -// A null user always compares less than a real user. This is done to provide -// easy values to search for the beginning of the users of a particular -// definition (i.e. using {def, nullptr}). -struct UserEntryLess { - bool operator()(const UserEntry& lhs, const UserEntry& rhs) const { - // If lhs.first and rhs.first are both null, fall through to checking the - // second entries. - if (!lhs.first && rhs.first) return true; - if (lhs.first && !rhs.first) return false; - - // If neither definition is null, then compare unique ids. - if (lhs.first && rhs.first) { - if (lhs.first->unique_id() < rhs.first->unique_id()) return true; - if (rhs.first->unique_id() < lhs.first->unique_id()) return false; - } - - // Return false on equality. - if (!lhs.second && !rhs.second) return false; - if (!lhs.second) return true; - if (!rhs.second) return false; - - // If neither user is null then compare unique ids. - return lhs.second->unique_id() < rhs.second->unique_id(); - } -}; - // A class for analyzing and managing defs and uses in an Module. class DefUseManager { public: using IdToDefMap = std::unordered_map<uint32_t, Instruction*>; - using IdToUsersMap = std::set<UserEntry, UserEntryLess>; // Constructs a def-use manager from the given |module|. All internal messages // will be communicated to the outside via the given message |consumer|. This // instance only keeps a reference to the |consumer|, so the |consumer| should // outlive this instance. - DefUseManager(Module* module) { AnalyzeDefUse(module); } + DefUseManager(Module* module) : DefUseManager() { AnalyzeDefUse(module); } DefUseManager(const DefUseManager&) = delete; DefUseManager(DefUseManager&&) = delete; @@ -191,14 +147,12 @@ class DefUseManager { // Returns the annotation instrunctions which are a direct use of the given // |id|. This means when the decorations are applied through decoration // group(s), this function will just return the OpGroupDecorate - // instrcution(s) which refer to the given id as an operand. The OpDecorate + // instruction(s) which refer to the given id as an operand. The OpDecorate // instructions which decorate the decoration group will not be returned. std::vector<Instruction*> GetAnnotations(uint32_t id) const; // Returns the map from ids to their def instructions. const IdToDefMap& id_to_defs() const { return id_to_def_; } - // Returns the map from instructions to their users. - const IdToUsersMap& id_to_users() const { return id_to_users_; } // Clear the internal def-use record of the given instruction |inst|. This // method will update the use information of the operand ids of |inst|. The @@ -210,43 +164,43 @@ class DefUseManager { // Erases the records that a given instruction uses its operand ids. void EraseUseRecordsOfOperandIds(const Instruction* inst); - friend bool operator==(const DefUseManager&, const DefUseManager&); - friend bool operator!=(const DefUseManager& lhs, const DefUseManager& rhs) { - return !(lhs == rhs); - } + friend bool CompareAndPrintDifferences(const DefUseManager&, + const DefUseManager&); - // If |inst| has not already been analysed, then analyses its defintion and + // If |inst| has not already been analysed, then analyses its definition and // uses. void UpdateDefUse(Instruction* inst); + // Compacts any internal storage to save memory. + void CompactStorage(); + private: - using InstToUsedIdsMap = - std::unordered_map<const Instruction*, std::vector<uint32_t>>; + using UseList = spvtools::utils::PooledLinkedList<Instruction*>; + using UseListPool = spvtools::utils::PooledLinkedListNodes<Instruction*>; + // Stores linked lists of Instructions using a def. + using InstToUsersMap = std::unordered_map<const Instruction*, UseList>; - // Returns the first location that {|def|, nullptr} could be inserted into the - // users map without violating ordering. - IdToUsersMap::const_iterator UsersBegin(const Instruction* def) const; + using UsedIdList = spvtools::utils::PooledLinkedList<uint32_t>; + using UsedIdListPool = spvtools::utils::PooledLinkedListNodes<uint32_t>; + // Stores mapping from instruction to their UsedIdRange. + using InstToUsedIdMap = std::unordered_map<const Instruction*, UsedIdList>; - // Returns true if |iter| has not reached the end of |def|'s users. - // - // In the first version |iter| is compared against the end of the map for - // validity before other checks. In the second version, |iter| is compared - // against |cached_end| for validity before other checks. This allows caching - // the map's end which is a performance improvement on some platforms. - bool UsersNotEnd(const IdToUsersMap::const_iterator& iter, - const Instruction* def) const; - bool UsersNotEnd(const IdToUsersMap::const_iterator& iter, - const IdToUsersMap::const_iterator& cached_end, - const Instruction* def) const; + DefUseManager(); // Analyzes the defs and uses in the given |module| and populates data // structures in this class. Does nothing if |module| is nullptr. void AnalyzeDefUse(Module* module); - IdToDefMap id_to_def_; // Mapping from ids to their definitions - IdToUsersMap id_to_users_; // Mapping from ids to their users - // Mapping from instructions to the ids used in the instruction. - InstToUsedIdsMap inst_to_used_ids_; + // Removes unused entries in used_records_ and used_ids_. + void CompactUseRecords(); + void CompactUsedIds(); + + IdToDefMap id_to_def_; // Mapping from ids to their definitions + InstToUsersMap inst_to_users_; // Map from def to uses. + std::unique_ptr<UseListPool> use_pool_; + + std::unique_ptr<UsedIdListPool> used_id_pool_; + InstToUsedIdMap inst_to_used_id_; // Map from instruction to used ids. }; } // namespace analysis diff --git a/source/opt/desc_sroa.cpp b/source/opt/desc_sroa.cpp index bcbdde94..b130ca80 100644 --- a/source/opt/desc_sroa.cpp +++ b/source/opt/desc_sroa.cpp @@ -118,7 +118,7 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var, if (use->NumInOperands() == 2) { // We are not indexing into the replacement variable. We can replaces the - // access chain with the replacement varibale itself. + // access chain with the replacement variable itself. context()->ReplaceAllUsesWith(use->result_id(), replacement_var); context()->KillInst(use); return true; @@ -135,8 +135,8 @@ bool DescriptorScalarReplacement::ReplaceAccessChain(Instruction* var, // Use the replacement variable as the base address. new_operands.push_back({SPV_OPERAND_TYPE_ID, {replacement_var}}); - // Drop the first index because it is consumed by the replacment, and copy the - // rest. + // Drop the first index because it is consumed by the replacement, and copy + // the rest. for (uint32_t i = 4; i < use->NumOperands(); i++) { new_operands.emplace_back(use->GetOperand(i)); } @@ -169,7 +169,7 @@ void DescriptorScalarReplacement::CopyDecorationsForNewVariable( Instruction* old_var, uint32_t index, uint32_t new_var_id, uint32_t new_var_ptr_type_id, const bool is_old_var_array, const bool is_old_var_struct, Instruction* old_var_type) { - // Handle OpDecorate instructions. + // Handle OpDecorate and OpDecorateString instructions. for (auto old_decoration : get_decoration_mgr()->GetDecorationsFor(old_var->result_id(), true)) { uint32_t new_binding = 0; @@ -212,7 +212,8 @@ uint32_t DescriptorScalarReplacement::GetNewBindingForElement( void DescriptorScalarReplacement::CreateNewDecorationForNewVariable( Instruction* old_decoration, uint32_t new_var_id, uint32_t new_binding) { - assert(old_decoration->opcode() == SpvOpDecorate); + assert(old_decoration->opcode() == SpvOpDecorate || + old_decoration->opcode() == SpvOpDecorateString); std::unique_ptr<Instruction> new_decoration(old_decoration->Clone(context())); new_decoration->SetInOperand(0, {new_var_id}); diff --git a/source/opt/desc_sroa.h b/source/opt/desc_sroa.h index fea06255..6a24fd87 100644 --- a/source/opt/desc_sroa.h +++ b/source/opt/desc_sroa.h @@ -115,10 +115,11 @@ class DescriptorScalarReplacement : public Pass { const bool is_old_var_struct, Instruction* old_var_type); - // Create a new OpDecorate instruction by cloning |old_decoration|. The new - // OpDecorate instruction will be used for a variable whose id is - // |new_var_ptr_type_id|. If |old_decoration| is a decoration for a binding, - // the new OpDecorate instruction will have |new_binding| as its binding. + // Create a new OpDecorate(String) instruction by cloning |old_decoration|. + // The new OpDecorate(String) instruction will be used for a variable whose id + // is |new_var_ptr_type_id|. If |old_decoration| is a decoration for a + // binding, the new OpDecorate(String) instruction will have |new_binding| as + // its binding. void CreateNewDecorationForNewVariable(Instruction* old_decoration, uint32_t new_var_id, uint32_t new_binding); @@ -131,7 +132,7 @@ class DescriptorScalarReplacement : public Pass { // A map from an OpVariable instruction to the set of variables that will be // used to replace it. The entry |replacement_variables_[var][i]| is the id of - // a variable that will be used in the place of the the ith element of the + // a variable that will be used in the place of the ith element of the // array |var|. If the entry is |0|, then the variable has not been // created yet. std::map<Instruction*, std::vector<uint32_t>> replacement_variables_; diff --git a/source/opt/dominator_tree.cpp b/source/opt/dominator_tree.cpp index 55287f44..d86de151 100644 --- a/source/opt/dominator_tree.cpp +++ b/source/opt/dominator_tree.cpp @@ -48,7 +48,7 @@ namespace { // BBType - BasicBlock type. Will either be BasicBlock or DominatorTreeNode // SuccessorLambda - Lamdba matching the signature of 'const // std::vector<BBType>*(const BBType *A)'. Will return a vector of the nodes -// succeding BasicBlock A. +// succeeding BasicBlock A. // PostLambda - Lamdba matching the signature of 'void (const BBType*)' will be // called on each node traversed AFTER their children. // PreLambda - Lamdba matching the signature of 'void (const BBType*)' will be @@ -69,7 +69,7 @@ static void DepthFirstSearch(const BBType* bb, SuccessorLambda successors, // BBType - BasicBlock type. Will either be BasicBlock or DominatorTreeNode // SuccessorLambda - Lamdba matching the signature of 'const // std::vector<BBType>*(const BBType *A)'. Will return a vector of the nodes -// succeding BasicBlock A. +// succeeding BasicBlock A. // PostLambda - Lamdba matching the signature of 'void (const BBType*)' will be // called on each node traversed after their children. template <typename BBType, typename SuccessorLambda, typename PostLambda> diff --git a/source/opt/dominator_tree.h b/source/opt/dominator_tree.h index 0024bc50..1674b228 100644 --- a/source/opt/dominator_tree.h +++ b/source/opt/dominator_tree.h @@ -278,7 +278,7 @@ class DominatorTree { private: // Wrapper function which gets the list of pairs of each BasicBlocks to its - // immediately dominating BasicBlock and stores the result in the the edges + // immediately dominating BasicBlock and stores the result in the edges // parameter. // // The |edges| vector will contain the dominator tree as pairs of nodes. diff --git a/source/opt/eliminate_dead_input_components_pass.cpp b/source/opt/eliminate_dead_input_components_pass.cpp new file mode 100644 index 00000000..f383136d --- /dev/null +++ b/source/opt/eliminate_dead_input_components_pass.cpp @@ -0,0 +1,146 @@ +// Copyright (c) 2022 The Khronos Group Inc. +// Copyright (c) 2022 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/eliminate_dead_input_components_pass.h" + +#include <set> +#include <vector> + +#include "source/opt/instruction.h" +#include "source/opt/ir_builder.h" +#include "source/opt/ir_context.h" +#include "source/util/bit_vector.h" + +namespace { + +const uint32_t kAccessChainBaseInIdx = 0; +const uint32_t kAccessChainIndex0InIdx = 1; +const uint32_t kConstantValueInIdx = 0; +const uint32_t kVariableStorageClassInIdx = 0; + +} // namespace + +namespace spvtools { +namespace opt { + +Pass::Status EliminateDeadInputComponentsPass::Process() { + // Current functionality assumes shader capability + if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) + return Status::SuccessWithoutChange; + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + bool modified = false; + std::vector<std::pair<Instruction*, unsigned>> arrays_to_change; + for (auto& var : context()->types_values()) { + if (var.opcode() != SpvOpVariable) { + continue; + } + analysis::Type* var_type = type_mgr->GetType(var.type_id()); + analysis::Pointer* ptr_type = var_type->AsPointer(); + if (ptr_type == nullptr) { + continue; + } + if (ptr_type->storage_class() != SpvStorageClassInput) { + continue; + } + const analysis::Array* arr_type = ptr_type->pointee_type()->AsArray(); + if (arr_type == nullptr) { + continue; + } + unsigned arr_len_id = arr_type->LengthId(); + Instruction* arr_len_inst = def_use_mgr->GetDef(arr_len_id); + if (arr_len_inst->opcode() != SpvOpConstant) { + continue; + } + // SPIR-V requires array size is >= 1, so this works for signed or + // unsigned size + unsigned original_max = + arr_len_inst->GetSingleWordInOperand(kConstantValueInIdx) - 1; + unsigned max_idx = FindMaxIndex(var, original_max); + if (max_idx != original_max) { + ChangeArrayLength(var, max_idx + 1); + modified = true; + } + } + + return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange; +} + +unsigned EliminateDeadInputComponentsPass::FindMaxIndex(Instruction& var, + unsigned original_max) { + unsigned max = 0; + bool seen_non_const_ac = false; + assert(var.opcode() == SpvOpVariable && "must be variable"); + context()->get_def_use_mgr()->WhileEachUser( + var.result_id(), [&max, &seen_non_const_ac, var, this](Instruction* use) { + auto use_opcode = use->opcode(); + if (use_opcode == SpvOpLoad || use_opcode == SpvOpCopyMemory || + use_opcode == SpvOpCopyMemorySized || + use_opcode == SpvOpCopyObject) { + seen_non_const_ac = true; + return false; + } + if (use->opcode() != SpvOpAccessChain && + use->opcode() != SpvOpInBoundsAccessChain) { + return true; + } + // OpAccessChain with no indices currently not optimized + if (use->NumInOperands() == 1) { + seen_non_const_ac = true; + return false; + } + unsigned base_id = use->GetSingleWordInOperand(kAccessChainBaseInIdx); + USE_ASSERT(base_id == var.result_id() && "unexpected base"); + unsigned idx_id = use->GetSingleWordInOperand(kAccessChainIndex0InIdx); + Instruction* idx_inst = context()->get_def_use_mgr()->GetDef(idx_id); + if (idx_inst->opcode() != SpvOpConstant) { + seen_non_const_ac = true; + return false; + } + unsigned value = idx_inst->GetSingleWordInOperand(kConstantValueInIdx); + if (value > max) max = value; + return true; + }); + return seen_non_const_ac ? original_max : max; +} + +void EliminateDeadInputComponentsPass::ChangeArrayLength(Instruction& arr, + unsigned length) { + analysis::TypeManager* type_mgr = context()->get_type_mgr(); + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + analysis::Pointer* ptr_type = type_mgr->GetType(arr.type_id())->AsPointer(); + const analysis::Array* arr_ty = ptr_type->pointee_type()->AsArray(); + assert(arr_ty && "expecting array type"); + uint32_t length_id = const_mgr->GetUIntConst(length); + analysis::Array new_arr_ty(arr_ty->element_type(), + arr_ty->GetConstantLengthInfo(length_id, length)); + analysis::Type* reg_new_arr_ty = type_mgr->GetRegisteredType(&new_arr_ty); + analysis::Pointer new_ptr_ty(reg_new_arr_ty, SpvStorageClassInput); + analysis::Type* reg_new_ptr_ty = type_mgr->GetRegisteredType(&new_ptr_ty); + uint32_t new_ptr_ty_id = type_mgr->GetTypeInstruction(reg_new_ptr_ty); + arr.SetResultType(new_ptr_ty_id); + def_use_mgr->AnalyzeInstUse(&arr); + // Move array OpVariable instruction after its new type to preserve order + USE_ASSERT(arr.GetSingleWordInOperand(kVariableStorageClassInIdx) != + SpvStorageClassFunction && + "cannot move Function variable"); + Instruction* new_ptr_ty_inst = def_use_mgr->GetDef(new_ptr_ty_id); + arr.RemoveFromList(); + arr.InsertAfter(new_ptr_ty_inst); +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/eliminate_dead_input_components_pass.h b/source/opt/eliminate_dead_input_components_pass.h new file mode 100644 index 00000000..b77857f4 --- /dev/null +++ b/source/opt/eliminate_dead_input_components_pass.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022 The Khronos Group Inc. +// Copyright (c) 2022 LunarG Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_ELIMINATE_DEAD_INPUT_COMPONENTS_H_ +#define SOURCE_OPT_ELIMINATE_DEAD_INPUT_COMPONENTS_H_ + +#include <unordered_map> + +#include "source/opt/ir_context.h" +#include "source/opt/module.h" +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class EliminateDeadInputComponentsPass : public Pass { + public: + explicit EliminateDeadInputComponentsPass() {} + + const char* name() const override { return "reduce-load-size"; } + Status Process() override; + + // Return the mask of preserved Analyses. + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | + IRContext::kAnalysisInstrToBlockMapping | + IRContext::kAnalysisCombinators | IRContext::kAnalysisCFG | + IRContext::kAnalysisDominatorAnalysis | + IRContext::kAnalysisLoopAnalysis | IRContext::kAnalysisNameMap | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // Find the max constant used to index the variable declared by |var| + // through OpAccessChain or OpInBoundsAccessChain. If any non-constant + // indices or non-Op*AccessChain use of |var|, return |original_max|. + unsigned FindMaxIndex(Instruction& var, unsigned original_max); + + // Change the length of the array |inst| to |length| + void ChangeArrayLength(Instruction& inst, unsigned length); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_ELIMINATE_DEAD_INPUT_COMPONENTS_H_ diff --git a/source/opt/eliminate_dead_members_pass.cpp b/source/opt/eliminate_dead_members_pass.cpp index a24ba8f4..52aca525 100644 --- a/source/opt/eliminate_dead_members_pass.cpp +++ b/source/opt/eliminate_dead_members_pass.cpp @@ -38,7 +38,7 @@ Pass::Status EliminateDeadMembersPass::Process() { } void EliminateDeadMembersPass::FindLiveMembers() { - // Until we have implemented the rewritting of OpSpecConsantOp instructions, + // Until we have implemented the rewriting of OpSpecConsantOp instructions, // we have to mark them as fully used just to be safe. for (auto& inst : get_module()->types_values()) { if (inst.opcode() == SpvOpSpecConstantOp) { @@ -570,7 +570,7 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) { Instruction* type_inst = get_def_use_mgr()->GetDef(type_id); switch (type_inst->opcode()) { case SpvOpTypeStruct: - // The type will have already been rewriten, so use the new member + // The type will have already been rewritten, so use the new member // index. type_id = type_inst->GetSingleWordInOperand(new_member_idx); break; diff --git a/source/opt/feature_manager.cpp b/source/opt/feature_manager.cpp index 39a4a348..a5902716 100644 --- a/source/opt/feature_manager.cpp +++ b/source/opt/feature_manager.cpp @@ -39,8 +39,7 @@ void FeatureManager::AddExtension(Instruction* ext) { assert(ext->opcode() == SpvOpExtension && "Expecting an extension instruction."); - const std::string name = - reinterpret_cast<const char*>(ext->GetInOperand(0u).words.data()); + const std::string name = ext->GetInOperand(0u).AsString(); Extension extension; if (GetExtensionFromString(name.c_str(), &extension)) { extensions_.Add(extension); diff --git a/source/opt/fold.cpp b/source/opt/fold.cpp index 6550fb4f..b903da6a 100644 --- a/source/opt/fold.cpp +++ b/source/opt/fold.cpp @@ -540,7 +540,7 @@ std::vector<uint32_t> InstructionFolder::FoldVectors( // in 32-bit words here. The reason of not using FoldScalars() here // is that we do not create temporary null constants as components // when the vector operand is a NullConstant because Constant creation - // may need extra checks for the validity and that is not manageed in + // may need extra checks for the validity and that is not managed in // here. if (const analysis::ScalarConstant* scalar_component = vector_operand->GetComponents().at(d)->AsScalarConstant()) { diff --git a/source/opt/fold_spec_constant_op_and_composite_pass.cpp b/source/opt/fold_spec_constant_op_and_composite_pass.cpp index 8ab717ea..85f11fde 100644 --- a/source/opt/fold_spec_constant_op_and_composite_pass.cpp +++ b/source/opt/fold_spec_constant_op_and_composite_pass.cpp @@ -115,7 +115,7 @@ bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp( Instruction* folded_inst = nullptr; assert(inst->GetInOperand(0).type == SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER && - "The first in-operand of OpSpecContantOp instruction must be of " + "The first in-operand of OpSpecConstantOp instruction must be of " "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type"); switch (static_cast<SpvOp>(inst->GetSingleWordInOperand(0))) { diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 4904f186..c879a0c5 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -2368,7 +2368,7 @@ FoldingRule VectorShuffleFeedingShuffle() { // fold. return false; } - } else { + } else if (component_index != undef_literal) { if (new_feeder_id == 0) { // First time through, save the id of the operand the element comes // from. @@ -2382,7 +2382,7 @@ FoldingRule VectorShuffleFeedingShuffle() { component_index -= feeder_op0_length; } - if (!feeder_is_op0) { + if (!feeder_is_op0 && component_index != undef_literal) { component_index += op0_length; } } diff --git a/source/opt/graphics_robust_access_pass.cpp b/source/opt/graphics_robust_access_pass.cpp index 1b28f9b5..4652d72d 100644 --- a/source/opt/graphics_robust_access_pass.cpp +++ b/source/opt/graphics_robust_access_pass.cpp @@ -13,7 +13,7 @@ // limitations under the License. // This pass injects code in a graphics shader to implement guarantees -// satisfying Vulkan's robustBufferAcces rules. Robust access rules permit +// satisfying Vulkan's robustBufferAccess rules. Robust access rules permit // an out-of-bounds access to be redirected to an access of the same type // (load, store, etc.) but within the same root object. // @@ -74,7 +74,7 @@ // Pointers are always (correctly) typed and so the address and number of // consecutive locations are fully determined by the pointer. // -// - A pointer value orginates as one of few cases: +// - A pointer value originates as one of few cases: // // - OpVariable for an interface object or an array of them: image, // buffer (UBO or SSBO), sampler, sampled-image, push-constant, input @@ -559,21 +559,17 @@ uint32_t GraphicsRobustAccessPass::GetGlslInsts() { if (module_status_.glsl_insts_id == 0) { // This string serves double-duty as raw data for a string and for a vector // of 32-bit words - const char glsl[] = "GLSL.std.450\0\0\0\0"; - const size_t glsl_str_byte_len = 16; + const char glsl[] = "GLSL.std.450"; // Use an existing import if we can. for (auto& inst : context()->module()->ext_inst_imports()) { - const auto& name_words = inst.GetInOperand(0).words; - if (0 == std::strncmp(reinterpret_cast<const char*>(name_words.data()), - glsl, glsl_str_byte_len)) { + if (inst.GetInOperand(0).AsString() == glsl) { module_status_.glsl_insts_id = inst.result_id(); } } if (module_status_.glsl_insts_id == 0) { // Make a new import instruction. module_status_.glsl_insts_id = TakeNextId(); - std::vector<uint32_t> words(glsl_str_byte_len / sizeof(uint32_t)); - std::memcpy(words.data(), glsl, glsl_str_byte_len); + std::vector<uint32_t> words = spvtools::utils::MakeVector(glsl); auto import_inst = MakeUnique<Instruction>( context(), SpvOpExtInstImport, 0, module_status_.glsl_insts_id, std::initializer_list<Operand>{ @@ -962,7 +958,7 @@ spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer( constant_mgr->GetDefiningInstruction(component_0)->result_id(); // If the image is a cube array, then the last component of the queried - // size is the layer count. In the query, we have to accomodate folding + // size is the layer count. In the query, we have to accommodate folding // in the face index ranging from 0 through 5. The inclusive upper bound // on the third coordinate therefore is multiplied by 6. auto* query_size_including_faces = query_size; diff --git a/source/opt/graphics_robust_access_pass.h b/source/opt/graphics_robust_access_pass.h index 6fc692c1..8f4c9dc7 100644 --- a/source/opt/graphics_robust_access_pass.h +++ b/source/opt/graphics_robust_access_pass.h @@ -111,7 +111,7 @@ class GraphicsRobustAccessPass : public Pass { Instruction* max, Instruction* where); // Returns a new instruction which evaluates to the length the runtime array - // referenced by the access chain at the specfied index. The instruction is + // referenced by the access chain at the specified index. The instruction is // inserted before the access chain instruction. Returns a null pointer in // some cases if assumptions are violated (rather than asserting out). opt::Instruction* MakeRuntimeArrayLengthInst(Instruction* access_chain, diff --git a/source/opt/inst_bindless_check_pass.cpp b/source/opt/inst_bindless_check_pass.cpp index 5607239a..c2c5d6cb 100644 --- a/source/opt/inst_bindless_check_pass.cpp +++ b/source/opt/inst_bindless_check_pass.cpp @@ -39,13 +39,6 @@ static const int kSpvTypeImageMS = 4; static const int kSpvTypeImageSampled = 5; } // anonymous namespace -// Avoid unused variable warning/error on Linux -#ifndef NDEBUG -#define USE_ASSERT(x) assert(x) -#else -#define USE_ASSERT(x) ((void)(x)) -#endif - namespace spvtools { namespace opt { diff --git a/source/opt/inst_bindless_check_pass.h b/source/opt/inst_bindless_check_pass.h index cd961805..e6e6ef4f 100644 --- a/source/opt/inst_bindless_check_pass.h +++ b/source/opt/inst_bindless_check_pass.h @@ -147,11 +147,11 @@ class InstBindlessCheckPass : public InstrumentPass { uint32_t GenLastByteIdx(RefAnalysis* ref, InstructionBuilder* builder); // Clone original image computation starting at |image_id| into |builder|. - // This may generate more than one instruction if neccessary. + // This may generate more than one instruction if necessary. uint32_t CloneOriginalImage(uint32_t image_id, InstructionBuilder* builder); // Clone original original reference encapsulated by |ref| into |builder|. - // This may generate more than one instruction if neccessary. + // This may generate more than one instruction if necessary. uint32_t CloneOriginalReference(RefAnalysis* ref, InstructionBuilder* builder); diff --git a/source/opt/inst_debug_printf_pass.cpp b/source/opt/inst_debug_printf_pass.cpp index c0e6bc3f..4218138f 100644 --- a/source/opt/inst_debug_printf_pass.cpp +++ b/source/opt/inst_debug_printf_pass.cpp @@ -16,6 +16,7 @@ #include "inst_debug_printf_pass.h" +#include "source/util/string_utils.h" #include "spirv/unified1/NonSemanticDebugPrintf.h" namespace spvtools { @@ -231,10 +232,8 @@ Pass::Status InstDebugPrintfPass::ProcessImpl() { bool non_sem_set_seen = false; for (auto c_itr = context()->module()->ext_inst_import_begin(); c_itr != context()->module()->ext_inst_import_end(); ++c_itr) { - const char* set_name = - reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]); - const char* non_sem_str = "NonSemantic."; - if (!strncmp(set_name, non_sem_str, strlen(non_sem_str))) { + const std::string set_name = c_itr->GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(set_name, "NonSemantic.")) { non_sem_set_seen = true; break; } @@ -242,9 +241,8 @@ Pass::Status InstDebugPrintfPass::ProcessImpl() { if (!non_sem_set_seen) { for (auto c_itr = context()->module()->extension_begin(); c_itr != context()->module()->extension_end(); ++c_itr) { - const char* ext_name = - reinterpret_cast<const char*>(&c_itr->GetInOperand(0).words[0]); - if (!strcmp(ext_name, "SPV_KHR_non_semantic_info")) { + const std::string ext_name = c_itr->GetInOperand(0).AsString(); + if (ext_name == "SPV_KHR_non_semantic_info") { context()->KillInst(&*c_itr); break; } diff --git a/source/opt/instruction.cpp b/source/opt/instruction.cpp index 2461e41e..418f1213 100644 --- a/source/opt/instruction.cpp +++ b/source/opt/instruction.cpp @@ -76,10 +76,9 @@ Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst, dbg_scope_(kNoDebugScope, kNoInlinedAt) { for (uint32_t i = 0; i < inst.num_operands; ++i) { const auto& current_payload = inst.operands[i]; - std::vector<uint32_t> words( - inst.words + current_payload.offset, + operands_.emplace_back( + current_payload.type, inst.words + current_payload.offset, inst.words + current_payload.offset + current_payload.num_words); - operands_.emplace_back(current_payload.type, std::move(words)); } assert((!IsLineInst() || dbg_line.empty()) && "Op(No)Line attaching to Op(No)Line found"); @@ -96,10 +95,9 @@ Instruction::Instruction(IRContext* c, const spv_parsed_instruction_t& inst, dbg_scope_(dbg_scope) { for (uint32_t i = 0; i < inst.num_operands; ++i) { const auto& current_payload = inst.operands[i]; - std::vector<uint32_t> words( - inst.words + current_payload.offset, + operands_.emplace_back( + current_payload.type, inst.words + current_payload.offset, inst.words + current_payload.offset + current_payload.num_words); - operands_.emplace_back(current_payload.type, std::move(words)); } } diff --git a/source/opt/instruction.h b/source/opt/instruction.h index ce568f66..2163d99b 100644 --- a/source/opt/instruction.h +++ b/source/opt/instruction.h @@ -24,6 +24,7 @@ #include "NonSemanticShaderDebugInfo100.h" #include "OpenCLDebugInfo100.h" +#include "source/binary.h" #include "source/common_debug_info.h" #include "source/latest_version_glsl_std_450_header.h" #include "source/latest_version_spirv_header.h" @@ -32,6 +33,7 @@ #include "source/opt/reflect.h" #include "source/util/ilist_node.h" #include "source/util/small_vector.h" +#include "source/util/string_utils.h" #include "spirv-tools/libspirv.h" const uint32_t kNoDebugScope = 0; @@ -82,21 +84,32 @@ struct Operand { Operand(spv_operand_type_t t, const OperandData& w) : type(t), words(w) {} + template <class InputIt> + Operand(spv_operand_type_t t, InputIt firstOperandData, + InputIt lastOperandData) + : type(t), words(firstOperandData, lastOperandData) {} + spv_operand_type_t type; // Type of this logical operand. OperandData words; // Binary segments of this logical operand. - // Returns a string operand as a C-style string. - const char* AsCString() const { - assert(type == SPV_OPERAND_TYPE_LITERAL_STRING); - return reinterpret_cast<const char*>(words.data()); + uint32_t AsId() const { + assert(spvIsIdType(type)); + assert(words.size() == 1); + return words[0]; } // Returns a string operand as a std::string. - std::string AsString() const { return AsCString(); } + std::string AsString() const { + assert(type == SPV_OPERAND_TYPE_LITERAL_STRING); + return spvtools::utils::MakeString(words); + } // Returns a literal integer operand as a uint64_t uint64_t AsLiteralUint64() const { - assert(type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER); + assert(type == SPV_OPERAND_TYPE_LITERAL_INTEGER || + type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER || + type == SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER || + type == SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER); assert(1 <= words.size()); assert(words.size() <= 2); uint64_t result = 0; @@ -123,7 +136,7 @@ inline bool operator!=(const Operand& o1, const Operand& o2) { } // This structure is used to represent a DebugScope instruction from -// the OpenCL.100.DebugInfo extened instruction set. Note that we can +// the OpenCL.100.DebugInfo extended instruction set. Note that we can // ignore the result id of DebugScope instruction because it is not // used for anything. We do not keep it to reduce the size of // structure. @@ -295,6 +308,7 @@ class Instruction : public utils::IntrusiveNodeBase<Instruction> { inline void SetInOperands(OperandList&& new_operands); // Sets the result type id. inline void SetResultType(uint32_t ty_id); + inline bool HasResultType() const { return has_type_id_; } // Sets the result id inline void SetResultId(uint32_t res_id); inline bool HasResultId() const { return has_result_id_; } diff --git a/source/opt/instrument_pass.h b/source/opt/instrument_pass.h index 12b939d4..90c1dd47 100644 --- a/source/opt/instrument_pass.h +++ b/source/opt/instrument_pass.h @@ -50,7 +50,7 @@ // A validation pass may read or write multiple buffers. All such buffers // are located in a single debug descriptor set whose index is passed at the // creation of the instrumentation pass. The bindings of the buffers used by -// a validation pass are permanantly assigned and fixed and documented by +// a validation pass are permanently assigned and fixed and documented by // the kDebugOutput* static consts. namespace spvtools { @@ -179,8 +179,8 @@ class InstrumentPass : public Pass { // the error. Every stage will write a fixed number of words. Vertex shaders // will write the Vertex and Instance ID. Fragment shaders will write // FragCoord.xy. Compute shaders will write the GlobalInvocation ID. - // The tesselation eval shader will write the Primitive ID and TessCoords.uv. - // The tesselation control shader and geometry shader will write the + // The tessellation eval shader will write the Primitive ID and TessCoords.uv. + // The tessellation control shader and geometry shader will write the // Primitive ID and Invocation ID. // // The Validation Error Code specifies the exact error which has occurred. diff --git a/source/opt/interp_fixup_pass.cpp b/source/opt/interp_fixup_pass.cpp index ad29e6a7..e8cdd99f 100644 --- a/source/opt/interp_fixup_pass.cpp +++ b/source/opt/interp_fixup_pass.cpp @@ -31,13 +31,6 @@ namespace { // Input Operand Indices static const int kSpvVariableStorageClassInIdx = 0; -// Avoid unused variable warning/error on Linux -#ifndef NDEBUG -#define USE_ASSERT(x) assert(x) -#else -#define USE_ASSERT(x) ((void)(x)) -#endif - // Folding rule function which attempts to replace |op(OpLoad(a),...)| // by |op(a,...)|, where |op| is one of the GLSLstd450 InterpolateAt* // instructions. Returns true if replaced, false otherwise. diff --git a/source/opt/ir_context.cpp b/source/opt/ir_context.cpp index 612a831a..a80d4f2d 100644 --- a/source/opt/ir_context.cpp +++ b/source/opt/ir_context.cpp @@ -41,6 +41,8 @@ namespace spvtools { namespace opt { void IRContext::BuildInvalidAnalyses(IRContext::Analysis set) { + set = Analysis(set & ~valid_analyses_); + if (set & kAnalysisDefUse) { BuildDefUseManager(); } @@ -106,7 +108,7 @@ void IRContext::InvalidateAnalyses(IRContext::Analysis analyses_to_invalidate) { analyses_to_invalidate |= kAnalysisDebugInfo; } - // The dominator analysis hold the psuedo entry and exit nodes from the CFG. + // The dominator analysis hold the pseudo entry and exit nodes from the CFG. // Also if the CFG change the dominators many changed as well, so the // dominator analysis should be invalidated as well. if (analyses_to_invalidate & kAnalysisCFG) { @@ -317,7 +319,7 @@ bool IRContext::IsConsistent() { #else if (AreAnalysesValid(kAnalysisDefUse)) { analysis::DefUseManager new_def_use(module()); - if (*get_def_use_mgr() != new_def_use) { + if (!CompareAndPrintDifferences(*get_def_use_mgr(), new_def_use)) { return false; } } @@ -623,9 +625,8 @@ void IRContext::AddCombinatorsForCapability(uint32_t capability) { void IRContext::AddCombinatorsForExtension(Instruction* extension) { assert(extension->opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast<const char*>(&extension->GetInOperand(0).words[0]); - if (!strcmp(extension_name, "GLSL.std.450")) { + const std::string extension_name = extension->GetInOperand(0).AsString(); + if (extension_name == "GLSL.std.450") { combinator_ops_[extension->result_id()] = {GLSLstd450Round, GLSLstd450RoundEven, GLSLstd450Trunc, @@ -944,11 +945,11 @@ void IRContext::EmitErrorMessage(std::string message, Instruction* inst) { uint32_t line_number = 0; uint32_t col_number = 0; - char* source = nullptr; + std::string source; if (line_inst != nullptr) { Instruction* file_name = get_def_use_mgr()->GetDef(line_inst->GetSingleWordInOperand(0)); - source = reinterpret_cast<char*>(&file_name->GetInOperand(0).words[0]); + source = file_name->GetInOperand(0).AsString(); // Get the line number and column number. line_number = line_inst->GetSingleWordInOperand(1); @@ -957,7 +958,7 @@ void IRContext::EmitErrorMessage(std::string message, Instruction* inst) { message += "\n " + inst->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); - consumer()(SPV_MSG_ERROR, source, {line_number, col_number, 0}, + consumer()(SPV_MSG_ERROR, source.c_str(), {line_number, col_number, 0}, message.c_str()); } diff --git a/source/opt/ir_context.h b/source/opt/ir_context.h index 65853476..946f9e9d 100644 --- a/source/opt/ir_context.h +++ b/source/opt/ir_context.h @@ -43,6 +43,7 @@ #include "source/opt/type_manager.h" #include "source/opt/value_number_table.h" #include "source/util/make_unique.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -301,7 +302,7 @@ class IRContext { } } - // Returns a pointer the decoration manager. If the decoration manger is + // Returns a pointer the decoration manager. If the decoration manager is // invalid, it is rebuilt first. analysis::DecorationManager* get_decoration_mgr() { if (!AreAnalysesValid(kAnalysisDecorations)) { @@ -384,7 +385,7 @@ class IRContext { // Deletes the instruction defining the given |id|. Returns true on // success, false if the given |id| is not defined at all. This method also - // erases the name, decorations, and defintion of |id|. + // erases the name, decorations, and definition of |id|. // // Pointers and iterators pointing to the deleted instructions become invalid. // However other pointers and iterators are still valid. @@ -518,6 +519,18 @@ class IRContext { std::string message = "ID overflow. Try running compact-ids."; consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); } +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + // If TakeNextId returns 0, it is very likely that execution will + // subsequently fail. Such failures are false alarms from a fuzzing point + // of view: they are due to the fact that too many ids were used, rather + // than being due to an actual bug. Thus, during a fuzzing build, it is + // preferable to bail out when ID overflow occurs. + // + // A zero exit code is returned here because a non-zero code would cause + // ClusterFuzz/OSS-Fuzz to regard the termination as a crash, and spurious + // crash reports is what this guard aims to avoid. + exit(0); +#endif } return next_id; } @@ -789,7 +802,7 @@ class IRContext { // iterators to traverse instructions. std::unordered_map<uint32_t, Function*> id_to_func_; - // A bitset indicating which analyes are currently valid. + // A bitset indicating which analyzes are currently valid. Analysis valid_analyses_; // Opcodes of shader capability core executable instructions @@ -854,8 +867,7 @@ inline IRContext::Analysis operator|(IRContext::Analysis lhs, inline IRContext::Analysis& operator|=(IRContext::Analysis& lhs, IRContext::Analysis rhs) { - lhs = static_cast<IRContext::Analysis>(static_cast<int>(lhs) | - static_cast<int>(rhs)); + lhs = lhs | rhs; return lhs; } @@ -1020,11 +1032,7 @@ void IRContext::AddCapability(std::unique_ptr<Instruction>&& c) { } void IRContext::AddExtension(const std::string& ext_name) { - const auto num_chars = ext_name.size(); - // Compute num words, accommodate the terminating null character. - const auto num_words = (num_chars + 1 + 3) / 4; - std::vector<uint32_t> ext_words(num_words, 0u); - std::memcpy(ext_words.data(), ext_name.data(), num_chars); + std::vector<uint32_t> ext_words = spvtools::utils::MakeVector(ext_name); AddExtension(std::unique_ptr<Instruction>( new Instruction(this, SpvOpExtension, 0u, 0u, {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}}))); @@ -1041,11 +1049,7 @@ void IRContext::AddExtension(std::unique_ptr<Instruction>&& e) { } void IRContext::AddExtInstImport(const std::string& name) { - const auto num_chars = name.size(); - // Compute num words, accommodate the terminating null character. - const auto num_words = (num_chars + 1 + 3) / 4; - std::vector<uint32_t> ext_words(num_words, 0u); - std::memcpy(ext_words.data(), name.data(), num_chars); + std::vector<uint32_t> ext_words = spvtools::utils::MakeVector(name); AddExtInstImport(std::unique_ptr<Instruction>( new Instruction(this, SpvOpExtInstImport, 0u, TakeNextId(), {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}}))); diff --git a/source/opt/ir_loader.cpp b/source/opt/ir_loader.cpp index a82b530e..97db9d8f 100644 --- a/source/opt/ir_loader.cpp +++ b/source/opt/ir_loader.cpp @@ -189,7 +189,8 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) { module_->SetMemoryModel(std::move(spv_inst)); } else if (opcode == SpvOpEntryPoint) { module_->AddEntryPoint(std::move(spv_inst)); - } else if (opcode == SpvOpExecutionMode) { + } else if (opcode == SpvOpExecutionMode || + opcode == SpvOpExecutionModeId) { module_->AddExecutionMode(std::move(spv_inst)); } else if (IsDebug1Inst(opcode)) { module_->AddDebug1Inst(std::move(spv_inst)); diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index da9ba8cc..0c6d0c24 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -19,6 +19,7 @@ #include "ir_builder.h" #include "ir_context.h" #include "iterator.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -328,8 +329,7 @@ bool LocalAccessChainConvertPass::AllExtensionsSupported() const { return false; // If any extension not in allowlist, return false for (auto& ei : get_module()->extensions()) { - const char* extName = - reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]); + const std::string extName = ei.GetInOperand(0).AsString(); if (extensions_allowlist_.find(extName) == extensions_allowlist_.end()) return false; } @@ -339,11 +339,9 @@ bool LocalAccessChainConvertPass::AllExtensionsSupported() const { for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12) && - 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100", - 32)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.") && + extension_name != "NonSemantic.Shader.DebugInfo.100") { return false; } } @@ -436,6 +434,7 @@ void LocalAccessChainConvertPass::InitExtensions() { "SPV_KHR_integer_dot_product", "SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info", + "SPV_KHR_uniform_group_instructions", }); } diff --git a/source/opt/local_access_chain_convert_pass.h b/source/opt/local_access_chain_convert_pass.h index 552062e5..a51660f1 100644 --- a/source/opt/local_access_chain_convert_pass.h +++ b/source/opt/local_access_chain_convert_pass.h @@ -81,7 +81,7 @@ class LocalAccessChainConvertPass : public MemPass { std::vector<Operand>* in_opnds); // Create a load/insert/store equivalent to a store of - // |valId| through (constant index) access chaing |ptrInst|. + // |valId| through (constant index) access chain |ptrInst|. // Append to |newInsts|. Returns true if successful. bool GenAccessChainStoreReplacement( const Instruction* ptrInst, uint32_t valId, diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp index 5fd4f658..33c8bdf8 100644 --- a/source/opt/local_single_block_elim_pass.cpp +++ b/source/opt/local_single_block_elim_pass.cpp @@ -19,6 +19,7 @@ #include <vector> #include "source/opt/iterator.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -183,8 +184,7 @@ void LocalSingleBlockLoadStoreElimPass::Initialize() { bool LocalSingleBlockLoadStoreElimPass::AllExtensionsSupported() const { // If any extension not in allowlist, return false for (auto& ei : get_module()->extensions()) { - const char* extName = - reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]); + const std::string extName = ei.GetInOperand(0).AsString(); if (extensions_allowlist_.find(extName) == extensions_allowlist_.end()) return false; } @@ -194,11 +194,9 @@ bool LocalSingleBlockLoadStoreElimPass::AllExtensionsSupported() const { for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12) && - 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100", - 32)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.") && + extension_name != "NonSemantic.Shader.DebugInfo.100") { return false; } } @@ -288,6 +286,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() { "SPV_KHR_integer_dot_product", "SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info", + "SPV_KHR_uniform_group_instructions", }); } diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp index 051bcada..f22b1911 100644 --- a/source/opt/local_single_store_elim_pass.cpp +++ b/source/opt/local_single_store_elim_pass.cpp @@ -19,6 +19,7 @@ #include "source/cfa.h" #include "source/latest_version_glsl_std_450_header.h" #include "source/opt/iterator.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -48,8 +49,7 @@ bool LocalSingleStoreElimPass::LocalSingleStoreElim(Function* func) { bool LocalSingleStoreElimPass::AllExtensionsSupported() const { // If any extension not in allowlist, return false for (auto& ei : get_module()->extensions()) { - const char* extName = - reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]); + const std::string extName = ei.GetInOperand(0).AsString(); if (extensions_allowlist_.find(extName) == extensions_allowlist_.end()) return false; } @@ -59,11 +59,9 @@ bool LocalSingleStoreElimPass::AllExtensionsSupported() const { for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12) && - 0 != std::strncmp(extension_name, "NonSemantic.Shader.DebugInfo.100", - 32)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.") && + extension_name != "NonSemantic.Shader.DebugInfo.100") { return false; } } @@ -141,6 +139,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() { "SPV_KHR_integer_dot_product", "SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info", + "SPV_KHR_uniform_group_instructions", }); } bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) { diff --git a/source/opt/loop_descriptor.cpp b/source/opt/loop_descriptor.cpp index b5b56309..9bc495e5 100644 --- a/source/opt/loop_descriptor.cpp +++ b/source/opt/loop_descriptor.cpp @@ -719,7 +719,7 @@ bool Loop::FindNumberOfIterations(const Instruction* induction, step_value = -step_value; } - // Find the inital value of the loop and make sure it is a constant integer. + // Find the initial value of the loop and make sure it is a constant integer. int64_t init_value = 0; if (!GetInductionInitValue(induction, &init_value)) return false; @@ -751,7 +751,7 @@ bool Loop::FindNumberOfIterations(const Instruction* induction, // We retrieve the number of iterations using the following formula, diff / // |step_value| where diff is calculated differently according to the // |condition| and uses the |condition_value| and |init_value|. If diff / -// |step_value| is NOT cleanly divisable then we add one to the sum. +// |step_value| is NOT cleanly divisible then we add one to the sum. int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value, int64_t init_value, int64_t step_value) const { int64_t diff = 0; @@ -795,7 +795,7 @@ int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value, // If the condition is not met to begin with the loop will never iterate. if (!(init_value >= condition_value)) return 0; - // We subract one to make it the same as SpvOpGreaterThan as it is + // We subtract one to make it the same as SpvOpGreaterThan as it is // functionally equivalent. diff = init_value - (condition_value - 1); diff --git a/source/opt/loop_descriptor.h b/source/opt/loop_descriptor.h index 4b4f8bc7..e88ff936 100644 --- a/source/opt/loop_descriptor.h +++ b/source/opt/loop_descriptor.h @@ -395,7 +395,7 @@ class Loop { // Sets |merge| as the loop merge block. No checks are performed here. inline void SetMergeBlockImpl(BasicBlock* merge) { loop_merge_ = merge; } - // Each differnt loop |condition| affects how we calculate the number of + // Each different loop |condition| affects how we calculate the number of // iterations using the |condition_value|, |init_value|, and |step_values| of // the induction variable. This method will return the number of iterations in // a loop with those values for a given |condition|. diff --git a/source/opt/loop_fission.cpp b/source/opt/loop_fission.cpp index 0678113c..b4df8c62 100644 --- a/source/opt/loop_fission.cpp +++ b/source/opt/loop_fission.cpp @@ -29,7 +29,7 @@ // 2 - For each loop in the list, group each instruction into a set of related // instructions by traversing each instructions users and operands recursively. // We stop if we encounter an instruction we have seen before or an instruction -// which we don't consider relevent (i.e OpLoopMerge). We then group these +// which we don't consider relevant (i.e OpLoopMerge). We then group these // groups into two different sets, one for the first loop and one for the // second. // @@ -453,7 +453,7 @@ Pass::Status LoopFissionPass::Process() { for (Function& f : *context()->module()) { // We collect all the inner most loops in the function and run the loop // splitting util on each. The reason we do this is to allow us to iterate - // over each, as creating new loops will invalidate the the loop iterator. + // over each, as creating new loops will invalidate the loop iterator. std::vector<Loop*> inner_most_loops{}; LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(&f); for (Loop& loop : loop_descriptor) { diff --git a/source/opt/loop_fission.h b/source/opt/loop_fission.h index e7a59c18..9bc12c0f 100644 --- a/source/opt/loop_fission.h +++ b/source/opt/loop_fission.h @@ -33,7 +33,7 @@ namespace opt { class LoopFissionPass : public Pass { public: - // Fuction used to determine if a given loop should be split. Takes register + // Function used to determine if a given loop should be split. Takes register // pressure region for that loop as a parameter and returns true if the loop // should be split. using FissionCriteriaFunction = diff --git a/source/opt/loop_fusion.cpp b/source/opt/loop_fusion.cpp index 07d171a0..f3aab283 100644 --- a/source/opt/loop_fusion.cpp +++ b/source/opt/loop_fusion.cpp @@ -165,7 +165,7 @@ bool LoopFusion::AreCompatible() { // Check adjacency, |loop_0_| should come just before |loop_1_|. // There is always at least one block between loops, even if it's empty. - // We'll check at most 2 preceeding blocks. + // We'll check at most 2 preceding blocks. auto pre_header_1 = loop_1_->GetPreHeaderBlock(); @@ -712,7 +712,7 @@ void LoopFusion::Fuse() { ld->RemoveLoop(loop_1_); - // Kill unnessecary instructions and remove all empty blocks. + // Kill unnecessary instructions and remove all empty blocks. for (auto inst : instr_to_delete) { context_->KillInst(inst); } diff --git a/source/opt/loop_fusion.h b/source/opt/loop_fusion.h index d61d6783..769da5f1 100644 --- a/source/opt/loop_fusion.h +++ b/source/opt/loop_fusion.h @@ -40,7 +40,7 @@ class LoopFusion { // That means: // * they both have one induction variable // * they have the same upper and lower bounds - // - same inital value + // - same initial value // - same condition // * they have the same update step // * they are adjacent, with |loop_0| appearing before |loop_1| diff --git a/source/opt/loop_fusion_pass.h b/source/opt/loop_fusion_pass.h index 3a0be600..9d5b7ccd 100644 --- a/source/opt/loop_fusion_pass.h +++ b/source/opt/loop_fusion_pass.h @@ -33,7 +33,7 @@ class LoopFusionPass : public Pass { // Processes the given |module|. Returns Status::Failure if errors occur when // processing. Returns the corresponding Status::Success if processing is - // succesful to indicate whether changes have been made to the modue. + // successful to indicate whether changes have been made to the module. Status Process() override; private: diff --git a/source/opt/loop_peeling.h b/source/opt/loop_peeling.h index 413f896f..2a55fe44 100644 --- a/source/opt/loop_peeling.h +++ b/source/opt/loop_peeling.h @@ -261,7 +261,7 @@ class LoopPeelingPass : public Pass { // Processes the given |module|. Returns Status::Failure if errors occur when // processing. Returns the corresponding Status::Success if processing is - // succesful to indicate whether changes have been made to the modue. + // successful to indicate whether changes have been made to the module. Pass::Status Process() override; private: diff --git a/source/opt/loop_unroller.cpp b/source/opt/loop_unroller.cpp index aff191fe..28ff0729 100644 --- a/source/opt/loop_unroller.cpp +++ b/source/opt/loop_unroller.cpp @@ -163,7 +163,7 @@ struct LoopUnrollState { }; // This class implements the actual unrolling. It uses a LoopUnrollState to -// maintain the state of the unrolling inbetween steps. +// maintain the state of the unrolling in between steps. class LoopUnrollerUtilsImpl { public: using BasicBlockListTy = std::vector<std::unique_ptr<BasicBlock>>; @@ -209,7 +209,7 @@ class LoopUnrollerUtilsImpl { // Add all blocks_to_add_ to function_ at the |insert_point|. void AddBlocksToFunction(const BasicBlock* insert_point); - // Duplicates the |old_loop|, cloning each body and remaping the ids without + // Duplicates the |old_loop|, cloning each body and remapping the ids without // removing instructions or changing relative structure. Result will be stored // in |new_loop|. void DuplicateLoop(Loop* old_loop, Loop* new_loop); @@ -241,7 +241,7 @@ class LoopUnrollerUtilsImpl { // Remap all the in |basic_block| to new IDs and keep the mapping of new ids // to old // ids. |loop| is used to identify special loop blocks (header, continue, - // ect). + // etc). void AssignNewResultIds(BasicBlock* basic_block); // Using the map built by AssignNewResultIds, replace the uses in |inst| @@ -320,7 +320,7 @@ class LoopUnrollerUtilsImpl { // and then be remapped at the end. std::vector<Instruction*> loop_phi_instructions_; - // The number of loop iterations that the loop would preform pre-unroll. + // The number of loop iterations that the loop would perform pre-unroll. size_t number_of_loop_iterations_; // The amount that the loop steps each iteration. @@ -839,7 +839,7 @@ void LoopUnrollerUtilsImpl::DuplicateLoop(Loop* old_loop, Loop* new_loop) { new_loop->SetMergeBlock(new_merge); } -// Whenever the utility copies a block it stores it in a tempory buffer, this +// Whenever the utility copies a block it stores it in a temporary buffer, this // function adds the buffer into the Function. The blocks will be inserted // after the block |insert_point|. void LoopUnrollerUtilsImpl::AddBlocksToFunction( diff --git a/source/opt/loop_unswitch_pass.cpp b/source/opt/loop_unswitch_pass.cpp index d805ecf3..1ee7e5e2 100644 --- a/source/opt/loop_unswitch_pass.cpp +++ b/source/opt/loop_unswitch_pass.cpp @@ -118,7 +118,7 @@ class LoopUnswitch { // Find a value that can be used to select the default path. // If none are possible, then it will just use 0. The value does not matter - // because this path will never be taken becaues the new switch outside of + // because this path will never be taken because the new switch outside of // the loop cannot select this path either. std::vector<uint32_t> existing_values; for (uint32_t i = 2; i < switch_inst->NumInOperands(); i += 2) { diff --git a/source/opt/loop_unswitch_pass.h b/source/opt/loop_unswitch_pass.h index 3ecdd611..4f7295d4 100644 --- a/source/opt/loop_unswitch_pass.h +++ b/source/opt/loop_unswitch_pass.h @@ -30,7 +30,7 @@ class LoopUnswitchPass : public Pass { // Processes the given |module|. Returns Status::Failure if errors occur when // processing. Returns the corresponding Status::Success if processing is - // succesful to indicate whether changes have been made to the modue. + // successful to indicate whether changes have been made to the module. Pass::Status Process() override; private: diff --git a/source/opt/loop_utils.h b/source/opt/loop_utils.h index a4e61900..70060fc4 100644 --- a/source/opt/loop_utils.h +++ b/source/opt/loop_utils.h @@ -123,7 +123,7 @@ class LoopUtils { // Clone the |loop_| and make the new loop branch to the second loop on exit. Loop* CloneAndAttachLoopToHeader(LoopCloningResult* cloning_result); - // Perfom a partial unroll of |loop| by given |factor|. This will copy the + // Perform a partial unroll of |loop| by given |factor|. This will copy the // body of the loop |factor| times. So a |factor| of one would give a new loop // with the original body plus one unrolled copy body. bool PartiallyUnroll(size_t factor); @@ -139,7 +139,7 @@ class LoopUtils { // 1. That the loop is in structured order. // 2. That the continue block is a branch to the header. // 3. That the only phi used in the loop is the induction variable. - // TODO(stephen@codeplay.com): This is a temporary mesure, after the loop is + // TODO(stephen@codeplay.com): This is a temporary measure, after the loop is // converted into LCSAA form and has a single entry and exit we can rewrite // the other phis. // 4. That this is an inner most loop, or that loops contained within this diff --git a/source/opt/merge_return_pass.cpp b/source/opt/merge_return_pass.cpp index a962a7cc..7710deae 100644 --- a/source/opt/merge_return_pass.cpp +++ b/source/opt/merge_return_pass.cpp @@ -431,6 +431,7 @@ bool MergeReturnPass::BreakFromConstruct( std::list<BasicBlock*>* order, Instruction* break_merge_inst) { // Make sure the CFG is build here. If we don't then it becomes very hard // to know which new blocks need to be updated. + context()->InvalidateAnalyses(IRContext::kAnalysisCFG); context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG); // When predicating, be aware of whether this block is a header block, a diff --git a/source/opt/merge_return_pass.h b/source/opt/merge_return_pass.h index 4096ce7d..a35cf269 100644 --- a/source/opt/merge_return_pass.h +++ b/source/opt/merge_return_pass.h @@ -247,7 +247,7 @@ class MergeReturnPass : public MemPass { // Add new phi nodes for any id that no longer dominate all of it uses. A phi // node is added to a block |bb| for an id if the id is defined between the - // original immediate dominator of |bb| and its new immidiate dominator. It + // original immediate dominator of |bb| and its new immediate dominator. It // is assumed that at this point there are no unreachable blocks in the // control flow graph. void AddNewPhiNodes(); @@ -273,7 +273,7 @@ class MergeReturnPass : public MemPass { void InsertAfterElement(BasicBlock* element, BasicBlock* new_element, std::list<BasicBlock*>* list); - // Creates a single case switch around all of the exectuable code of the + // Creates a single case switch around all of the executable code of the // current function where the switch and case value are both zero and the // default is the merge block. Returns after the switch is executed. Sets // |final_return_block_|. diff --git a/source/opt/module.cpp b/source/opt/module.cpp index f97defbd..5983abb1 100644 --- a/source/opt/module.cpp +++ b/source/opt/module.cpp @@ -139,7 +139,7 @@ void Module::ToBinary(std::vector<uint32_t>* binary, bool skip_nop) const { // TODO(antiagainst): should we change the generator number? binary->push_back(header_.generator); binary->push_back(header_.bound); - binary->push_back(header_.reserved); + binary->push_back(header_.schema); size_t bound_idx = binary->size() - 2; DebugScope last_scope(kNoDebugScope, kNoInlinedAt); @@ -260,9 +260,7 @@ bool Module::HasExplicitCapability(uint32_t cap) { uint32_t Module::GetExtInstImportId(const char* extstr) { for (auto& ei : ext_inst_imports_) - if (!strcmp(extstr, - reinterpret_cast<const char*>(&(ei.GetInOperand(0).words[0])))) - return ei.result_id(); + if (!ei.GetInOperand(0).AsString().compare(extstr)) return ei.result_id(); return 0; } diff --git a/source/opt/module.h b/source/opt/module.h index 0360b7d5..230be709 100644 --- a/source/opt/module.h +++ b/source/opt/module.h @@ -36,7 +36,7 @@ struct ModuleHeader { uint32_t version; uint32_t generator; uint32_t bound; - uint32_t reserved; + uint32_t schema; }; // A SPIR-V module. It contains all the information for a SPIR-V module and @@ -61,7 +61,7 @@ class Module { } // Returns the Id bound. - uint32_t IdBound() { return header_.bound; } + uint32_t IdBound() const { return header_.bound; } // Returns the current Id bound and increases it to the next available value. // If the id bound has already reached its maximum value, then 0 is returned. @@ -141,6 +141,8 @@ class Module { inline uint32_t id_bound() const { return header_.bound; } inline uint32_t version() const { return header_.version; } + inline uint32_t generator() const { return header_.generator; } + inline uint32_t schema() const { return header_.schema; } inline void set_version(uint32_t v) { header_.version = v; } diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index e74db26f..f28b1baf 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -288,6 +288,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) { RegisterPass(CreateStripDebugInfoPass()); } else if (pass_name == "strip-reflect") { RegisterPass(CreateStripReflectInfoPass()); + } else if (pass_name == "strip-nonsemantic") { + RegisterPass(CreateStripNonSemanticInfoPass()); } else if (pass_name == "set-spec-const-default-value") { if (pass_args.size() > 0) { auto spec_ids_vals = @@ -322,6 +324,8 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) { RegisterPass(CreateLocalAccessChainConvertPass()); } else if (pass_name == "replace-desc-array-access-using-var-index") { RegisterPass(CreateReplaceDescArrayAccessUsingVarIndexPass()); + } else if (pass_name == "spread-volatile-semantics") { + RegisterPass(CreateSpreadVolatileSemanticsPass()); } else if (pass_name == "descriptor-scalar-replacement") { RegisterPass(CreateDescriptorScalarReplacementPass()); } else if (pass_name == "eliminate-dead-code-aggressive") { @@ -517,6 +521,10 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag) { RegisterPass(CreateAmdExtToKhrPass()); } else if (pass_name == "interpolate-fixup") { RegisterPass(CreateInterpolateFixupPass()); + } else if (pass_name == "remove-dont-inline") { + RegisterPass(CreateRemoveDontInlinePass()); + } else if (pass_name == "eliminate-dead-input-components") { + RegisterPass(CreateEliminateDeadInputComponentsPass()); } else if (pass_name == "convert-to-sampled-image") { if (pass_args.size() > 0) { auto descriptor_set_binding_pairs = @@ -653,8 +661,12 @@ Optimizer::PassToken CreateStripDebugInfoPass() { } Optimizer::PassToken CreateStripReflectInfoPass() { + return CreateStripNonSemanticInfoPass(); +} + +Optimizer::PassToken CreateStripNonSemanticInfoPass() { return MakeUnique<Optimizer::PassToken::Impl>( - MakeUnique<opt::StripReflectInfoPass>()); + MakeUnique<opt::StripNonSemanticInfoPass>()); } Optimizer::PassToken CreateEliminateDeadFunctionsPass() { @@ -764,6 +776,11 @@ Optimizer::PassToken CreateLocalMultiStoreElimPass() { MakeUnique<opt::SSARewritePass>()); } +Optimizer::PassToken CreateAggressiveDCEPass() { + return MakeUnique<Optimizer::PassToken::Impl>( + MakeUnique<opt::AggressiveDCEPass>(false)); +} + Optimizer::PassToken CreateAggressiveDCEPass(bool preserve_interface) { return MakeUnique<Optimizer::PassToken::Impl>( MakeUnique<opt::AggressiveDCEPass>(preserve_interface)); @@ -965,6 +982,11 @@ Optimizer::PassToken CreateReplaceDescArrayAccessUsingVarIndexPass() { MakeUnique<opt::ReplaceDescArrayAccessUsingVarIndex>()); } +Optimizer::PassToken CreateSpreadVolatileSemanticsPass() { + return MakeUnique<Optimizer::PassToken::Impl>( + MakeUnique<opt::SpreadVolatileSemantics>()); +} + Optimizer::PassToken CreateDescriptorScalarReplacementPass() { return MakeUnique<Optimizer::PassToken::Impl>( MakeUnique<opt::DescriptorScalarReplacement>()); @@ -984,6 +1006,11 @@ Optimizer::PassToken CreateInterpolateFixupPass() { MakeUnique<opt::InterpFixupPass>()); } +Optimizer::PassToken CreateEliminateDeadInputComponentsPass() { + return MakeUnique<Optimizer::PassToken::Impl>( + MakeUnique<opt::EliminateDeadInputComponentsPass>()); +} + Optimizer::PassToken CreateConvertToSampledImagePass( const std::vector<opt::DescriptorSetAndBinding>& descriptor_set_binding_pairs) { @@ -991,4 +1018,8 @@ Optimizer::PassToken CreateConvertToSampledImagePass( MakeUnique<opt::ConvertToSampledImagePass>(descriptor_set_binding_pairs)); } +Optimizer::PassToken CreateRemoveDontInlinePass() { + return MakeUnique<Optimizer::PassToken::Impl>( + MakeUnique<opt::RemoveDontInline>()); +} } // namespace spvtools diff --git a/source/opt/pass.h b/source/opt/pass.h index a8c9c4b4..b2303e23 100644 --- a/source/opt/pass.h +++ b/source/opt/pass.h @@ -28,6 +28,13 @@ #include "spirv-tools/libspirv.hpp" #include "types.h" +// Avoid unused variable warning/error on Linux +#ifndef NDEBUG +#define USE_ASSERT(x) assert(x) +#else +#define USE_ASSERT(x) ((void)(x)) +#endif + namespace spvtools { namespace opt { @@ -129,7 +136,7 @@ class Pass { // Processes the given |module|. Returns Status::Failure if errors occur when // processing. Returns the corresponding Status::Success if processing is - // succesful to indicate whether changes are made to the module. + // successful to indicate whether changes are made to the module. virtual Status Process() = 0; // Return the next available SSA id and increment it. diff --git a/source/opt/pass_manager.cpp b/source/opt/pass_manager.cpp index be53d344..a73ff7cf 100644 --- a/source/opt/pass_manager.cpp +++ b/source/opt/pass_manager.cpp @@ -35,10 +35,18 @@ Pass::Status PassManager::Run(IRContext* context) { if (print_all_stream_) { std::vector<uint32_t> binary; context->module()->ToBinary(&binary, false); - SpirvTools t(SPV_ENV_UNIVERSAL_1_2); + SpirvTools t(target_env_); + t.SetMessageConsumer(consumer()); std::string disassembly; - t.Disassemble(binary, &disassembly, 0); - *print_all_stream_ << preamble << (pass ? pass->name() : "") << "\n" + std::string pass_name = (pass ? pass->name() : ""); + if (!t.Disassemble(binary, &disassembly, 0)) { + std::string msg = "Disassembly failed before pass "; + msg += pass_name + "\n"; + spv_position_t null_pos{0, 0, 0}; + consumer()(SPV_MSG_WARNING, "", null_pos, msg.c_str()); + return; + } + *print_all_stream_ << preamble << pass_name << "\n" << disassembly << std::endl; } }; diff --git a/source/opt/pass_manager.h b/source/opt/pass_manager.h index 9686dddc..11961a33 100644 --- a/source/opt/pass_manager.h +++ b/source/opt/pass_manager.h @@ -54,7 +54,7 @@ class PassManager { // Adds an externally constructed pass. void AddPass(std::unique_ptr<Pass> pass); // Uses the argument |args| to construct a pass instance of type |T|, and adds - // the pass instance to this pass manger. The pass added will use this pass + // the pass instance to this pass manager. The pass added will use this pass // manager's message consumer. template <typename T, typename... Args> void AddPass(Args&&... args); @@ -70,7 +70,7 @@ class PassManager { // Runs all passes on the given |module|. Returns Status::Failure if errors // occur when processing using one of the registered passes. All passes // registered after the error-reporting pass will be skipped. Returns the - // corresponding Status::Success if processing is succesful to indicate + // corresponding Status::Success if processing is successful to indicate // whether changes are made to the module. // // After running all the passes, they are removed from the list. diff --git a/source/opt/passes.h b/source/opt/passes.h index f3c30d57..a12c76b8 100644 --- a/source/opt/passes.h +++ b/source/opt/passes.h @@ -34,6 +34,7 @@ #include "source/opt/desc_sroa.h" #include "source/opt/eliminate_dead_constant_pass.h" #include "source/opt/eliminate_dead_functions_pass.h" +#include "source/opt/eliminate_dead_input_components_pass.h" #include "source/opt/eliminate_dead_members_pass.h" #include "source/opt/empty_pass.h" #include "source/opt/fix_storage_class.h" @@ -64,6 +65,7 @@ #include "source/opt/reduce_load_size.h" #include "source/opt/redundancy_elimination.h" #include "source/opt/relax_float_ops_pass.h" +#include "source/opt/remove_dontinline_pass.h" #include "source/opt/remove_duplicates_pass.h" #include "source/opt/remove_unused_interface_variables_pass.h" #include "source/opt/replace_desc_array_access_using_var_index.h" @@ -71,10 +73,11 @@ #include "source/opt/scalar_replacement_pass.h" #include "source/opt/set_spec_constant_default_value_pass.h" #include "source/opt/simplification_pass.h" +#include "source/opt/spread_volatile_semantics.h" #include "source/opt/ssa_rewrite_pass.h" #include "source/opt/strength_reduction_pass.h" #include "source/opt/strip_debug_info_pass.h" -#include "source/opt/strip_reflect_info_pass.h" +#include "source/opt/strip_nonsemantic_info_pass.h" #include "source/opt/unify_const_pass.h" #include "source/opt/upgrade_memory_model.h" #include "source/opt/vector_dce.h" diff --git a/source/opt/private_to_local_pass.cpp b/source/opt/private_to_local_pass.cpp index 12a226d5..80fb4c53 100644 --- a/source/opt/private_to_local_pass.cpp +++ b/source/opt/private_to_local_pass.cpp @@ -135,7 +135,7 @@ bool PrivateToLocalPass::MoveVariable(Instruction* variable, // Place the variable at the start of the first basic block. context()->AnalyzeUses(variable); context()->set_instr_block(variable, &*function->begin()); - function->begin()->begin()->InsertBefore(move(var)); + function->begin()->begin()->InsertBefore(std::move(var)); // Update uses where the type may have changed. return UpdateUses(variable); diff --git a/source/opt/private_to_local_pass.h b/source/opt/private_to_local_pass.h index c6127d67..e96a965e 100644 --- a/source/opt/private_to_local_pass.h +++ b/source/opt/private_to_local_pass.h @@ -44,7 +44,7 @@ class PrivateToLocalPass : public Pass { // class of |function|. Returns false if the variable could not be moved. bool MoveVariable(Instruction* variable, Function* function); - // |inst| is an instruction declaring a varible. If that variable is + // |inst| is an instruction declaring a variable. If that variable is // referenced in a single function and all of uses are valid as defined by // |IsValidUse|, then that function is returned. Otherwise, the return // value is |nullptr|. diff --git a/source/opt/redundancy_elimination.h b/source/opt/redundancy_elimination.h index 91809b5d..40451f40 100644 --- a/source/opt/redundancy_elimination.h +++ b/source/opt/redundancy_elimination.h @@ -41,7 +41,7 @@ class RedundancyEliminationPass : public LocalRedundancyEliminationPass { // in the function containing |bb|. // // |value_to_ids| is a map from value number to ids. If {vn, id} is in - // |value_to_ids| then vn is the value number of id, and the defintion of id + // |value_to_ids| then vn is the value number of id, and the definition of id // dominates |bb|. // // Returns true if at least one instruction is deleted. diff --git a/source/opt/register_pressure.cpp b/source/opt/register_pressure.cpp index 5750c6d4..1ad33738 100644 --- a/source/opt/register_pressure.cpp +++ b/source/opt/register_pressure.cpp @@ -378,7 +378,7 @@ void RegisterLiveness::SimulateFusion( // The loop fusion is injecting the l1 before the l2, the latch of l1 will be // connected to the header of l2. // To compute the register usage, we inject the loop live-in (union of l1 and - // l2 live-in header blocks) into the the live in/out of each basic block of + // l2 live-in header blocks) into the live in/out of each basic block of // l1 to get the peak register usage. We then repeat the operation to for l2 // basic blocks but in this case we inject the live-out of the latch of l1. auto live_loop = MakeFilterIteratorRange( diff --git a/source/opt/remove_dontinline_pass.cpp b/source/opt/remove_dontinline_pass.cpp new file mode 100644 index 00000000..4dd1cd4f --- /dev/null +++ b/source/opt/remove_dontinline_pass.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/remove_dontinline_pass.h" + +namespace spvtools { +namespace opt { + +Pass::Status RemoveDontInline::Process() { + bool modified = false; + modified = ClearDontInlineFunctionControl(); + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); +} + +bool RemoveDontInline::ClearDontInlineFunctionControl() { + bool modified = false; + for (auto& func : *get_module()) { + ClearDontInlineFunctionControl(&func); + } + return modified; +} + +bool RemoveDontInline::ClearDontInlineFunctionControl(Function* function) { + constexpr uint32_t kFunctionControlInOperandIdx = 0; + Instruction* function_inst = &function->DefInst(); + uint32_t function_control = + function_inst->GetSingleWordInOperand(kFunctionControlInOperandIdx); + + if ((function_control & SpvFunctionControlDontInlineMask) == 0) { + return false; + } + function_control &= ~SpvFunctionControlDontInlineMask; + function_inst->SetInOperand(kFunctionControlInOperandIdx, {function_control}); + return true; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/remove_dontinline_pass.h b/source/opt/remove_dontinline_pass.h new file mode 100644 index 00000000..16243199 --- /dev/null +++ b/source/opt/remove_dontinline_pass.h @@ -0,0 +1,42 @@ +// Copyright (c) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_REMOVE_DONTINLINE_PASS_H_ +#define SOURCE_OPT_REMOVE_DONTINLINE_PASS_H_ + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class RemoveDontInline : public Pass { + public: + const char* name() const override { return "remove-dont-inline"; } + Status Process() override; + + private: + // Clears the DontInline function control from every function in the module. + // Returns true of a change was made. + bool ClearDontInlineFunctionControl(); + + // Clears the DontInline function control from |function|. + // Returns true of a change was made. + bool ClearDontInlineFunctionControl(Function* function); +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_REMOVE_DONTINLINE_PASS_H_ diff --git a/source/opt/remove_duplicates_pass.cpp b/source/opt/remove_duplicates_pass.cpp index 0e65cc8d..1ed8e2a0 100644 --- a/source/opt/remove_duplicates_pass.cpp +++ b/source/opt/remove_duplicates_pass.cpp @@ -72,9 +72,8 @@ bool RemoveDuplicatesPass::RemoveDuplicatesExtInstImports() const { std::unordered_map<std::string, SpvId> ext_inst_imports; for (auto* i = &*context()->ext_inst_import_begin(); i;) { - auto res = ext_inst_imports.emplace( - reinterpret_cast<const char*>(i->GetInOperand(0u).words.data()), - i->result_id()); + auto res = ext_inst_imports.emplace(i->GetInOperand(0u).AsString(), + i->result_id()); if (res.second) { // Never seen before, keep it. i = i->NextNode(); diff --git a/source/opt/replace_desc_array_access_using_var_index.cpp b/source/opt/replace_desc_array_access_using_var_index.cpp index 1082e679..4cadf600 100644 --- a/source/opt/replace_desc_array_access_using_var_index.cpp +++ b/source/opt/replace_desc_array_access_using_var_index.cpp @@ -253,8 +253,12 @@ void ReplaceDescArrayAccessUsingVarIndex::ReplaceNonUniformAccessWithSwitchCase( Instruction* access_chain_final_user, Instruction* access_chain, uint32_t number_of_elements, const std::deque<Instruction*>& insts_to_be_cloned) const { - // Create merge block and add terminator auto* block = context()->get_instr_block(access_chain_final_user); + // If the instruction does not belong to a block (i.e. in the case of + // OpDecorate), no replacement is needed. + if (!block) return; + + // Create merge block and add terminator auto* merge_block = SeparateInstructionsIntoNewBlock( block, access_chain_final_user->NextNode()); diff --git a/source/opt/replace_desc_array_access_using_var_index.h b/source/opt/replace_desc_array_access_using_var_index.h index e18222c8..0c97f7eb 100644 --- a/source/opt/replace_desc_array_access_using_var_index.h +++ b/source/opt/replace_desc_array_access_using_var_index.h @@ -47,7 +47,7 @@ class ReplaceDescArrayAccessUsingVarIndex : public Pass { } private: - // Replaces all acceses to |var| using variable indices with constant + // Replaces all accesses to |var| using variable indices with constant // elements of the array |var|. Creates switch-case statements to determine // the value of the variable index for all the possible cases. Returns // whether replacement is done or not. @@ -170,7 +170,7 @@ class ReplaceDescArrayAccessUsingVarIndex : public Pass { // Creates and adds an OpSwitch used for the selection of OpAccessChain whose // first Indexes operand is |access_chain_index_var_id|. The OpSwitch will be // added at the end of |parent_block|. It will jump to |default_id| for the - // default case and jumps to one of case blocks whoes ids are |case_block_ids| + // default case and jumps to one of case blocks whose ids are |case_block_ids| // if |access_chain_index_var_id| matches the case number. |merge_id| is the // merge block id. void AddSwitchForAccessChain( diff --git a/source/opt/replace_invalid_opc.cpp b/source/opt/replace_invalid_opc.cpp index e3b9d3e4..1dcd06f5 100644 --- a/source/opt/replace_invalid_opc.cpp +++ b/source/opt/replace_invalid_opc.cpp @@ -112,8 +112,7 @@ bool ReplaceInvalidOpcodePass::RewriteFunction(Function* function, } Instruction* file_name = context()->get_def_use_mgr()->GetDef(file_name_id); - const char* source = reinterpret_cast<const char*>( - &file_name->GetInOperand(0).words[0]); + const std::string source = file_name->GetInOperand(0).AsString(); // Get the line number and column number. uint32_t line_number = @@ -121,7 +120,7 @@ bool ReplaceInvalidOpcodePass::RewriteFunction(Function* function, uint32_t col_number = last_line_dbg_inst->GetSingleWordInOperand(2); // Replace the instruction. - ReplaceInstruction(inst, source, line_number, col_number); + ReplaceInstruction(inst, source.c_str(), line_number, col_number); } } }, diff --git a/source/opt/scalar_analysis.cpp b/source/opt/scalar_analysis.cpp index 38555e64..2b0a824c 100644 --- a/source/opt/scalar_analysis.cpp +++ b/source/opt/scalar_analysis.cpp @@ -581,7 +581,7 @@ static void PushToString(T id, std::u32string* str) { // Implements the hashing of SENodes. size_t SENodeHash::operator()(const SENode* node) const { - // Concatinate the terms into a string which we can hash. + // Concatenate the terms into a string which we can hash. std::u32string hash_string{}; // Hashing the type as a string is safer than hashing the enum as the enum is diff --git a/source/opt/scalar_analysis_nodes.h b/source/opt/scalar_analysis_nodes.h index b0e3fefd..91ce446f 100644 --- a/source/opt/scalar_analysis_nodes.h +++ b/source/opt/scalar_analysis_nodes.h @@ -167,7 +167,7 @@ class SENode { const ChildContainerType& GetChildren() const { return children_; } ChildContainerType& GetChildren() { return children_; } - // Return true if this node is a cant compute node. + // Return true if this node is a can't compute node. bool IsCantCompute() const { return GetType() == CanNotCompute; } // Implements a casting method for each type. diff --git a/source/opt/scalar_analysis_simplification.cpp b/source/opt/scalar_analysis_simplification.cpp index 52f2d6ad..3c1ecc08 100644 --- a/source/opt/scalar_analysis_simplification.cpp +++ b/source/opt/scalar_analysis_simplification.cpp @@ -88,7 +88,7 @@ class SENodeSimplifyImpl { private: // Recursively descend through the graph to build up the accumulator objects - // which are used to flatten the graph. |child| is the node currenty being + // which are used to flatten the graph. |child| is the node currently being // traversed and the |negation| flag is used to signify that this operation // was preceded by a unary negative operation and as such the result should be // negated. @@ -134,7 +134,7 @@ class SENodeSimplifyImpl { // offset. SENode* EliminateZeroCoefficientRecurrents(SENode* node); - // A reference the the analysis which requested the simplification. + // A reference the analysis which requested the simplification. ScalarEvolutionAnalysis& analysis_; // The node being simplified. diff --git a/source/opt/scalar_replacement_pass.cpp b/source/opt/scalar_replacement_pass.cpp index 4d6a7aad..e27c828b 100644 --- a/source/opt/scalar_replacement_pass.cpp +++ b/source/opt/scalar_replacement_pass.cpp @@ -24,6 +24,7 @@ #include "source/opt/reflect.h" #include "source/opt/types.h" #include "source/util/make_unique.h" +#include "types.h" static const uint32_t kDebugValueOperandValueIndex = 5; static const uint32_t kDebugValueOperandExpressionIndex = 6; @@ -395,7 +396,7 @@ bool ScalarReplacementPass::CreateReplacementVariables( if (!components_used || components_used->count(elem)) { CreateVariable(*id, inst, elem, replacements); } else { - replacements->push_back(CreateNullConstant(*id)); + replacements->push_back(GetUndef(*id)); } elem++; }); @@ -406,8 +407,8 @@ bool ScalarReplacementPass::CreateReplacementVariables( CreateVariable(type->GetSingleWordInOperand(0u), inst, i, replacements); } else { - replacements->push_back( - CreateNullConstant(type->GetSingleWordInOperand(0u))); + uint32_t element_type_id = type->GetSingleWordInOperand(0); + replacements->push_back(GetUndef(element_type_id)); } } break; @@ -429,6 +430,10 @@ bool ScalarReplacementPass::CreateReplacementVariables( replacements->end(); } +Instruction* ScalarReplacementPass::GetUndef(uint32_t type_id) { + return get_def_use_mgr()->GetDef(Type2Undef(type_id)); +} + void ScalarReplacementPass::TransferAnnotations( const Instruction* source, std::vector<Instruction*>* replacements) { // Only transfer invariant and restrict decorations on the variable. There are @@ -981,20 +986,6 @@ ScalarReplacementPass::GetUsedComponents(Instruction* inst) { return result; } -Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) { - analysis::TypeManager* type_mgr = context()->get_type_mgr(); - analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); - - const analysis::Type* type = type_mgr->GetType(type_id); - const analysis::Constant* null_const = const_mgr->GetConstant(type, {}); - Instruction* null_inst = - const_mgr->GetDefiningInstruction(null_const, type_id); - if (null_inst != nullptr) { - context()->UpdateDefUse(null_inst); - } - return null_inst; -} - uint64_t ScalarReplacementPass::GetMaxLegalIndex( const Instruction* var_inst) const { assert(var_inst->opcode() == SpvOpVariable && diff --git a/source/opt/scalar_replacement_pass.h b/source/opt/scalar_replacement_pass.h index 0928830c..76afc267 100644 --- a/source/opt/scalar_replacement_pass.h +++ b/source/opt/scalar_replacement_pass.h @@ -23,14 +23,14 @@ #include <vector> #include "source/opt/function.h" -#include "source/opt/pass.h" +#include "source/opt/mem_pass.h" #include "source/opt/type_manager.h" namespace spvtools { namespace opt { // Documented in optimizer.hpp -class ScalarReplacementPass : public Pass { +class ScalarReplacementPass : public MemPass { private: static const uint32_t kDefaultLimit = 100; @@ -234,10 +234,8 @@ class ScalarReplacementPass : public Pass { std::unique_ptr<std::unordered_set<int64_t>> GetUsedComponents( Instruction* inst); - // Returns an instruction defining a null constant with type |type_id|. If - // one already exists, it is returned. Otherwise a new one is created. - // Returns |nullptr| if the new constant could not be created. - Instruction* CreateNullConstant(uint32_t type_id); + // Returns an instruction defining an undefined value type |type_id|. + Instruction* GetUndef(uint32_t type_id); // Maps storage type to a pointer type enclosing that type. std::unordered_map<uint32_t, uint32_t> pointee_to_pointer_; diff --git a/source/opt/spread_volatile_semantics.cpp b/source/opt/spread_volatile_semantics.cpp new file mode 100644 index 00000000..a1d34329 --- /dev/null +++ b/source/opt/spread_volatile_semantics.cpp @@ -0,0 +1,318 @@ +// Copyright (c) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/spread_volatile_semantics.h" + +#include "source/opt/decoration_manager.h" +#include "source/opt/ir_builder.h" +#include "source/spirv_constant.h" + +namespace spvtools { +namespace opt { +namespace { + +const uint32_t kOpDecorateInOperandBuiltinDecoration = 2u; +const uint32_t kOpLoadInOperandMemoryOperands = 1u; +const uint32_t kOpEntryPointInOperandEntryPoint = 1u; +const uint32_t kOpEntryPointInOperandInterface = 3u; + +bool HasBuiltinDecoration(analysis::DecorationManager* decoration_manager, + uint32_t var_id, uint32_t built_in) { + return decoration_manager->FindDecoration( + var_id, SpvDecorationBuiltIn, [built_in](const Instruction& inst) { + return built_in == inst.GetSingleWordInOperand( + kOpDecorateInOperandBuiltinDecoration); + }); +} + +bool IsBuiltInForRayTracingVolatileSemantics(uint32_t built_in) { + switch (built_in) { + case SpvBuiltInSMIDNV: + case SpvBuiltInWarpIDNV: + case SpvBuiltInSubgroupSize: + case SpvBuiltInSubgroupLocalInvocationId: + case SpvBuiltInSubgroupEqMask: + case SpvBuiltInSubgroupGeMask: + case SpvBuiltInSubgroupGtMask: + case SpvBuiltInSubgroupLeMask: + case SpvBuiltInSubgroupLtMask: + return true; + default: + return false; + } +} + +bool HasBuiltinForRayTracingVolatileSemantics( + analysis::DecorationManager* decoration_manager, uint32_t var_id) { + return decoration_manager->FindDecoration( + var_id, SpvDecorationBuiltIn, [](const Instruction& inst) { + uint32_t built_in = + inst.GetSingleWordInOperand(kOpDecorateInOperandBuiltinDecoration); + return IsBuiltInForRayTracingVolatileSemantics(built_in); + }); +} + +bool HasVolatileDecoration(analysis::DecorationManager* decoration_manager, + uint32_t var_id) { + return decoration_manager->HasDecoration(var_id, SpvDecorationVolatile); +} + +bool HasOnlyEntryPointsAsFunctions(IRContext* context, Module* module) { + std::unordered_set<uint32_t> entry_function_ids; + for (Instruction& entry_point : module->entry_points()) { + entry_function_ids.insert( + entry_point.GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint)); + } + for (auto& function : *module) { + if (entry_function_ids.find(function.result_id()) == + entry_function_ids.end()) { + std::string message( + "Functions of SPIR-V for spread-volatile-semantics pass input must " + "be inlined except entry points"); + message += "\n " + function.DefInst().PrettyPrint( + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + context->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); + return false; + } + } + return true; +} + +} // namespace + +Pass::Status SpreadVolatileSemantics::Process() { + if (HasNoExecutionModel()) { + return Status::SuccessWithoutChange; + } + + if (!HasOnlyEntryPointsAsFunctions(context(), get_module())) { + return Status::Failure; + } + + const bool is_vk_memory_model_enabled = + context()->get_feature_mgr()->HasCapability( + SpvCapabilityVulkanMemoryModel); + CollectTargetsForVolatileSemantics(is_vk_memory_model_enabled); + + // If VulkanMemoryModel capability is not enabled, we have to set Volatile + // decoration for interface variables instead of setting Volatile for load + // instructions. If an interface (or pointers to it) is used by two load + // instructions in two entry points and one must be volatile while another + // is not, we have to report an error for the conflict. + if (!is_vk_memory_model_enabled && + HasInterfaceInConflictOfVolatileSemantics()) { + return Status::Failure; + } + + return SpreadVolatileSemanticsToVariables(is_vk_memory_model_enabled); +} + +Pass::Status SpreadVolatileSemantics::SpreadVolatileSemanticsToVariables( + const bool is_vk_memory_model_enabled) { + Status status = Status::SuccessWithoutChange; + for (Instruction& var : context()->types_values()) { + auto entry_function_ids = + EntryFunctionsToSpreadVolatileSemanticsForVar(var.result_id()); + if (entry_function_ids.empty()) { + continue; + } + + if (is_vk_memory_model_enabled) { + SetVolatileForLoadsInEntries(&var, entry_function_ids); + } else { + DecorateVarWithVolatile(&var); + } + status = Status::SuccessWithChange; + } + return status; +} + +bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint( + uint32_t var_id, Instruction* entry_point) { + uint32_t entry_function_id = + entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint); + return !VisitLoadsOfPointersToVariableInEntries( + var_id, + [](Instruction* load) { + // If it has a load without volatile memory operand, finish traversal + // and return false. + if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) { + return false; + } + uint32_t memory_operands = + load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands); + return (memory_operands & SpvMemoryAccessVolatileMask) != 0; + }, + {entry_function_id}); +} + +bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() { + for (Instruction& entry_point : get_module()->entry_points()) { + SpvExecutionModel execution_model = + static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0)); + for (uint32_t operand_index = kOpEntryPointInOperandInterface; + operand_index < entry_point.NumInOperands(); ++operand_index) { + uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index); + if (!EntryFunctionsToSpreadVolatileSemanticsForVar(var_id).empty() && + !IsTargetForVolatileSemantics(var_id, execution_model) && + IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) { + Instruction* inst = context()->get_def_use_mgr()->GetDef(var_id); + context()->EmitErrorMessage( + "Variable is a target for Volatile semantics for an entry point, " + "but it is not for another entry point", + inst); + return true; + } + } + } + return false; +} + +void SpreadVolatileSemantics::MarkVolatileSemanticsForVariable( + uint32_t var_id, Instruction* entry_point) { + uint32_t entry_function_id = + entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint); + auto itr = var_ids_to_entry_fn_for_volatile_semantics_.find(var_id); + if (itr == var_ids_to_entry_fn_for_volatile_semantics_.end()) { + var_ids_to_entry_fn_for_volatile_semantics_[var_id] = {entry_function_id}; + return; + } + itr->second.insert(entry_function_id); +} + +void SpreadVolatileSemantics::CollectTargetsForVolatileSemantics( + const bool is_vk_memory_model_enabled) { + for (Instruction& entry_point : get_module()->entry_points()) { + SpvExecutionModel execution_model = + static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0)); + for (uint32_t operand_index = kOpEntryPointInOperandInterface; + operand_index < entry_point.NumInOperands(); ++operand_index) { + uint32_t var_id = entry_point.GetSingleWordInOperand(operand_index); + if (!IsTargetForVolatileSemantics(var_id, execution_model)) { + continue; + } + if (is_vk_memory_model_enabled || + IsTargetUsedByNonVolatileLoadInEntryPoint(var_id, &entry_point)) { + MarkVolatileSemanticsForVariable(var_id, &entry_point); + } + } + } +} + +void SpreadVolatileSemantics::DecorateVarWithVolatile(Instruction* var) { + analysis::DecorationManager* decoration_manager = + context()->get_decoration_mgr(); + uint32_t var_id = var->result_id(); + if (HasVolatileDecoration(decoration_manager, var_id)) { + return; + } + get_decoration_mgr()->AddDecoration( + SpvOpDecorate, {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}}, + {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, + {SpvDecorationVolatile}}}); +} + +bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries( + uint32_t var_id, const std::function<bool(Instruction*)>& handle_load, + const std::unordered_set<uint32_t>& entry_function_ids) { + std::vector<uint32_t> worklist({var_id}); + auto* def_use_mgr = context()->get_def_use_mgr(); + while (!worklist.empty()) { + uint32_t ptr_id = worklist.back(); + worklist.pop_back(); + bool finish_traversal = !def_use_mgr->WhileEachUser( + ptr_id, [this, &worklist, &ptr_id, handle_load, + &entry_function_ids](Instruction* user) { + BasicBlock* block = context()->get_instr_block(user); + if (block == nullptr || + entry_function_ids.find(block->GetParent()->result_id()) == + entry_function_ids.end()) { + return true; + } + + if (user->opcode() == SpvOpAccessChain || + user->opcode() == SpvOpInBoundsAccessChain || + user->opcode() == SpvOpPtrAccessChain || + user->opcode() == SpvOpInBoundsPtrAccessChain || + user->opcode() == SpvOpCopyObject) { + if (ptr_id == user->GetSingleWordInOperand(0)) + worklist.push_back(user->result_id()); + return true; + } + + if (user->opcode() != SpvOpLoad) { + return true; + } + + return handle_load(user); + }); + if (finish_traversal) return false; + } + return true; +} + +void SpreadVolatileSemantics::SetVolatileForLoadsInEntries( + Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) { + // Set Volatile memory operand for all load instructions if they do not have + // it. + VisitLoadsOfPointersToVariableInEntries( + var->result_id(), + [](Instruction* load) { + if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) { + load->AddOperand( + {SPV_OPERAND_TYPE_MEMORY_ACCESS, {SpvMemoryAccessVolatileMask}}); + return true; + } + uint32_t memory_operands = + load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands); + memory_operands |= SpvMemoryAccessVolatileMask; + load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands}); + return true; + }, + entry_function_ids); +} + +bool SpreadVolatileSemantics::IsTargetForVolatileSemantics( + uint32_t var_id, SpvExecutionModel execution_model) { + analysis::DecorationManager* decoration_manager = + context()->get_decoration_mgr(); + if (execution_model == SpvExecutionModelFragment) { + return get_module()->version() >= SPV_SPIRV_VERSION_WORD(1, 6) && + HasBuiltinDecoration(decoration_manager, var_id, + SpvBuiltInHelperInvocation); + } + + if (execution_model == SpvExecutionModelIntersectionKHR || + execution_model == SpvExecutionModelIntersectionNV) { + if (HasBuiltinDecoration(decoration_manager, var_id, + SpvBuiltInRayTmaxKHR)) { + return true; + } + } + + switch (execution_model) { + case SpvExecutionModelRayGenerationKHR: + case SpvExecutionModelClosestHitKHR: + case SpvExecutionModelMissKHR: + case SpvExecutionModelCallableKHR: + case SpvExecutionModelIntersectionKHR: + return HasBuiltinForRayTracingVolatileSemantics(decoration_manager, + var_id); + default: + return false; + } +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/spread_volatile_semantics.h b/source/opt/spread_volatile_semantics.h new file mode 100644 index 00000000..531a21d5 --- /dev/null +++ b/source/opt/spread_volatile_semantics.h @@ -0,0 +1,117 @@ +// Copyright (c) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_SPREAD_VOLATILE_SEMANTICS_H_ +#define SOURCE_OPT_SPREAD_VOLATILE_SEMANTICS_H_ + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +class SpreadVolatileSemantics : public Pass { + public: + SpreadVolatileSemantics() {} + + const char* name() const override { return "spread-volatile-semantics"; } + + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDefUse | IRContext::kAnalysisDecorations | + IRContext::kAnalysisInstrToBlockMapping; + } + + private: + // Returns true if it does not have an execution model. Linkage shaders do not + // have an execution model. + bool HasNoExecutionModel() { + return get_module()->entry_points().empty() && + context()->get_feature_mgr()->HasCapability(SpvCapabilityLinkage); + } + + // Iterates interface variables and spreads the Volatile semantics if it has + // load instructions for the Volatile semantics. + Pass::Status SpreadVolatileSemanticsToVariables( + const bool is_vk_memory_model_enabled); + + // Returns whether |var_id| is the result id of a target builtin variable for + // the volatile semantics for |execution_model| based on the Vulkan spec + // VUID-StandaloneSpirv-VulkanMemoryModel-04678 or + // VUID-StandaloneSpirv-VulkanMemoryModel-04679. + bool IsTargetForVolatileSemantics(uint32_t var_id, + SpvExecutionModel execution_model); + + // Collects interface variables that need the volatile semantics. + // |is_vk_memory_model_enabled| is true if VulkanMemoryModel capability is + // enabled. + void CollectTargetsForVolatileSemantics( + const bool is_vk_memory_model_enabled); + + // Reports an error if an interface variable is used by two entry points and + // it needs the Volatile decoration for one but not for another. Returns true + // if the error must be reported. + bool HasInterfaceInConflictOfVolatileSemantics(); + + // Returns whether the variable whose result is |var_id| is used by a + // non-volatile load or a pointer to it is used by a non-volatile load in + // |entry_point| or not. + bool IsTargetUsedByNonVolatileLoadInEntryPoint(uint32_t var_id, + Instruction* entry_point); + + // Visits load instructions of pointers to variable whose result id is + // |var_id| if the load instructions are in entry points whose + // function id is one of |entry_function_ids|. |handle_load| is a function to + // do some actions for the load instructions. Finishes the traversal and + // returns false if |handle_load| returns false for a load instruction. + // Otherwise, returns true after running |handle_load| for all the load + // instructions. + bool VisitLoadsOfPointersToVariableInEntries( + uint32_t var_id, const std::function<bool(Instruction*)>& handle_load, + const std::unordered_set<uint32_t>& entry_function_ids); + + // Sets Memory Operands of OpLoad instructions that load |var| or pointers + // of |var| as Volatile if the function id of the OpLoad instruction is + // included in |entry_function_ids|. + void SetVolatileForLoadsInEntries( + Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids); + + // Adds OpDecorate Volatile for |var| if it does not exist. + void DecorateVarWithVolatile(Instruction* var); + + // Returns a set of entry function ids to spread the volatile semantics for + // the variable with the result id |var_id|. + std::unordered_set<uint32_t> EntryFunctionsToSpreadVolatileSemanticsForVar( + uint32_t var_id) { + auto itr = var_ids_to_entry_fn_for_volatile_semantics_.find(var_id); + if (itr == var_ids_to_entry_fn_for_volatile_semantics_.end()) return {}; + return itr->second; + } + + // Specifies that we have to spread the volatile semantics for the + // variable with the result id |var_id| for the entry point |entry_point|. + void MarkVolatileSemanticsForVariable(uint32_t var_id, + Instruction* entry_point); + + // Result ids of variables to entry function ids for the volatile semantics + // spread. + std::unordered_map<uint32_t, std::unordered_set<uint32_t>> + var_ids_to_entry_fn_for_volatile_semantics_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_SPREAD_VOLATILE_SEMANTICS_H_ diff --git a/source/opt/strength_reduction_pass.h b/source/opt/strength_reduction_pass.h index 8dfeb307..1cbbbcc6 100644 --- a/source/opt/strength_reduction_pass.h +++ b/source/opt/strength_reduction_pass.h @@ -34,7 +34,7 @@ class StrengthReductionPass : public Pass { // Returns true if something changed. bool ReplaceMultiplyByPowerOf2(BasicBlock::iterator*); - // Scan the types and constants in the module looking for the the integer + // Scan the types and constants in the module looking for the integer // types that we are // interested in. The shift operation needs a small unsigned integer. We // need to find diff --git a/source/opt/strip_debug_info_pass.cpp b/source/opt/strip_debug_info_pass.cpp index c86ce578..6a0ebf24 100644 --- a/source/opt/strip_debug_info_pass.cpp +++ b/source/opt/strip_debug_info_pass.cpp @@ -14,6 +14,7 @@ #include "source/opt/strip_debug_info_pass.h" #include "source/opt/ir_context.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -21,9 +22,8 @@ namespace opt { Pass::Status StripDebugInfoPass::Process() { bool uses_non_semantic_info = false; for (auto& inst : context()->module()->extensions()) { - const char* ext_name = - reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]); - if (0 == std::strcmp(ext_name, "SPV_KHR_non_semantic_info")) { + const std::string ext_name = inst.GetInOperand(0).AsString(); + if (ext_name == "SPV_KHR_non_semantic_info") { uses_non_semantic_info = true; } } @@ -46,9 +46,10 @@ Pass::Status StripDebugInfoPass::Process() { if (use->opcode() == SpvOpExtInst) { auto ext_inst_set = def_use->GetDef(use->GetSingleWordInOperand(0u)); - const char* extension_name = reinterpret_cast<const char*>( - &ext_inst_set->GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12)) { + const std::string extension_name = + ext_inst_set->GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, + "NonSemantic.")) { // found a non-semantic use, return false as we cannot // remove this OpString return false; diff --git a/source/opt/strip_reflect_info_pass.cpp b/source/opt/strip_nonsemantic_info_pass.cpp index 8b0f2db7..cd1fbb63 100644 --- a/source/opt/strip_reflect_info_pass.cpp +++ b/source/opt/strip_nonsemantic_info_pass.cpp @@ -12,18 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "source/opt/strip_reflect_info_pass.h" +#include "source/opt/strip_nonsemantic_info_pass.h" #include <cstring> #include <vector> #include "source/opt/instruction.h" #include "source/opt/ir_context.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { -Pass::Status StripReflectInfoPass::Process() { +Pass::Status StripNonSemanticInfoPass::Process() { bool modified = false; std::vector<Instruction*> to_remove; @@ -32,7 +33,8 @@ Pass::Status StripReflectInfoPass::Process() { for (auto& inst : context()->module()->annotations()) { switch (inst.opcode()) { case SpvOpDecorateStringGOOGLE: - if (inst.GetSingleWordInOperand(1) == SpvDecorationHlslSemanticGOOGLE) { + if (inst.GetSingleWordInOperand(1) == SpvDecorationHlslSemanticGOOGLE || + inst.GetSingleWordInOperand(1) == SpvDecorationUserTypeGOOGLE) { to_remove.push_back(&inst); } else { other_uses_for_decorate_string = true; @@ -40,7 +42,8 @@ Pass::Status StripReflectInfoPass::Process() { break; case SpvOpMemberDecorateStringGOOGLE: - if (inst.GetSingleWordInOperand(2) == SpvDecorationHlslSemanticGOOGLE) { + if (inst.GetSingleWordInOperand(2) == SpvDecorationHlslSemanticGOOGLE || + inst.GetSingleWordInOperand(2) == SpvDecorationUserTypeGOOGLE) { to_remove.push_back(&inst); } else { other_uses_for_decorate_string = true; @@ -60,33 +63,26 @@ Pass::Status StripReflectInfoPass::Process() { } for (auto& inst : context()->module()->extensions()) { - const char* ext_name = - reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]); - if (0 == std::strcmp(ext_name, "SPV_GOOGLE_hlsl_functionality1")) { + const std::string ext_name = inst.GetInOperand(0).AsString(); + if (ext_name == "SPV_GOOGLE_hlsl_functionality1") { + to_remove.push_back(&inst); + } else if (ext_name == "SPV_GOOGLE_user_type") { to_remove.push_back(&inst); } else if (!other_uses_for_decorate_string && - 0 == std::strcmp(ext_name, "SPV_GOOGLE_decorate_string")) { + ext_name == "SPV_GOOGLE_decorate_string") { to_remove.push_back(&inst); - } else if (0 == std::strcmp(ext_name, "SPV_KHR_non_semantic_info")) { + } else if (ext_name == "SPV_KHR_non_semantic_info") { to_remove.push_back(&inst); } } - // clear all debug data now if it hasn't been cleared already, to remove any - // remaining OpString that may have been referenced by non-semantic extinsts - for (auto& dbg : context()->debugs1()) to_remove.push_back(&dbg); - for (auto& dbg : context()->debugs2()) to_remove.push_back(&dbg); - for (auto& dbg : context()->debugs3()) to_remove.push_back(&dbg); - for (auto& dbg : context()->ext_inst_debuginfo()) to_remove.push_back(&dbg); - // remove any extended inst imports that are non semantic std::unordered_set<uint32_t> non_semantic_sets; for (auto& inst : context()->module()->ext_inst_imports()) { assert(inst.opcode() == SpvOpExtInstImport && "Expecting an import of an extension's instruction set."); - const char* extension_name = - reinterpret_cast<const char*>(&inst.GetInOperand(0).words[0]); - if (0 == std::strncmp(extension_name, "NonSemantic.", 12)) { + const std::string extension_name = inst.GetInOperand(0).AsString(); + if (spvtools::utils::starts_with(extension_name, "NonSemantic.")) { non_semantic_sets.insert(inst.result_id()); to_remove.push_back(&inst); } @@ -103,19 +99,10 @@ Pass::Status StripReflectInfoPass::Process() { to_remove.push_back(inst); } } - }); + }, + true); } - // OpName must come first, since they may refer to other debug instructions. - // If they are after the instructions that refer to, then they will be killed - // when that instruction is killed, which will lead to a double kill. - std::sort(to_remove.begin(), to_remove.end(), - [](Instruction* lhs, Instruction* rhs) -> bool { - if (lhs->opcode() == SpvOpName && rhs->opcode() != SpvOpName) - return true; - return false; - }); - for (auto* inst : to_remove) { modified = true; context()->KillInst(inst); diff --git a/source/opt/strip_reflect_info_pass.h b/source/opt/strip_nonsemantic_info_pass.h index 4e1999ed..ff4e2e1d 100644 --- a/source/opt/strip_reflect_info_pass.h +++ b/source/opt/strip_nonsemantic_info_pass.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_ -#define SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_ +#ifndef SOURCE_OPT_STRIP_NONSEMANTIC_INFO_PASS_H_ +#define SOURCE_OPT_STRIP_NONSEMANTIC_INFO_PASS_H_ #include "source/opt/ir_context.h" #include "source/opt/module.h" @@ -23,9 +23,9 @@ namespace spvtools { namespace opt { // See optimizer.hpp for documentation. -class StripReflectInfoPass : public Pass { +class StripNonSemanticInfoPass : public Pass { public: - const char* name() const override { return "strip-reflect"; } + const char* name() const override { return "strip-nonsemantic"; } Status Process() override; // Return the mask of preserved Analyses. @@ -41,4 +41,4 @@ class StripReflectInfoPass : public Pass { } // namespace opt } // namespace spvtools -#endif // SOURCE_OPT_STRIP_REFLECT_INFO_PASS_H_ +#endif // SOURCE_OPT_STRIP_NONSEMANTIC_INFO_PASS_H_ diff --git a/source/opt/type_manager.cpp b/source/opt/type_manager.cpp index 7935ad33..a0006f55 100644 --- a/source/opt/type_manager.cpp +++ b/source/opt/type_manager.cpp @@ -23,6 +23,7 @@ #include "source/opt/log.h" #include "source/opt/reflect.h" #include "source/util/make_unique.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -234,6 +235,7 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { DefineParameterlessCase(PipeStorage); DefineParameterlessCase(NamedBarrier); DefineParameterlessCase(AccelerationStructureNV); + DefineParameterlessCase(RayQueryKHR); #undef DefineParameterlessCase case Type::kInteger: typeInst = MakeUnique<Instruction>( @@ -349,11 +351,8 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) { } case Type::kOpaque: { const Opaque* opaque = type->AsOpaque(); - size_t size = opaque->name().size(); // Convert to null-terminated packed UTF-8 string. - std::vector<uint32_t> words(size / 4 + 1, 0); - char* dst = reinterpret_cast<char*>(words.data()); - strncpy(dst, opaque->name().c_str(), size); + std::vector<uint32_t> words = spvtools::utils::MakeVector(opaque->name()); typeInst = MakeUnique<Instruction>( context(), SpvOpTypeOpaque, 0, id, std::initializer_list<Operand>{ @@ -529,6 +528,7 @@ Type* TypeManager::RebuildType(const Type& type) { DefineNoSubtypeCase(PipeStorage); DefineNoSubtypeCase(NamedBarrier); DefineNoSubtypeCase(AccelerationStructureNV); + DefineNoSubtypeCase(RayQueryKHR); #undef DefineNoSubtypeCase case Type::kVector: { const Vector* vec_ty = type.AsVector(); @@ -781,8 +781,7 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) { } } break; case SpvOpTypeOpaque: { - const uint32_t* data = inst.GetInOperand(0).words.data(); - type = new Opaque(reinterpret_cast<const char*>(data)); + type = new Opaque(inst.GetInOperand(0).AsString()); } break; case SpvOpTypePointer: { uint32_t pointee_type_id = inst.GetSingleWordInOperand(1); diff --git a/source/opt/type_manager.h b/source/opt/type_manager.h index ce9d83d4..72e37f48 100644 --- a/source/opt/type_manager.h +++ b/source/opt/type_manager.h @@ -160,6 +160,13 @@ class TypeManager { uint32_t GetFloatTypeId() { return GetTypeInstruction(GetFloatType()); } + Type* GetDoubleType() { + Float float_type(64); + return GetRegisteredType(&float_type); + } + + uint32_t GetDoubleTypeId() { return GetTypeInstruction(GetDoubleType()); } + Type* GetUIntVectorType(uint32_t size) { Vector vec_type(GetUIntType(), size); return GetRegisteredType(&vec_type); diff --git a/source/opt/types.cpp b/source/opt/types.cpp index b1eb3a50..ebbdc367 100644 --- a/source/opt/types.cpp +++ b/source/opt/types.cpp @@ -21,6 +21,7 @@ #include <string> #include <unordered_set> +#include "source/util/hash_combine.h" #include "source/util/make_unique.h" #include "spirv/unified1/spirv.h" @@ -28,6 +29,7 @@ namespace spvtools { namespace opt { namespace analysis { +using spvtools::utils::hash_combine; using U32VecVec = std::vector<std::vector<uint32_t>>; namespace { @@ -182,23 +184,26 @@ bool Type::operator==(const Type& other) const { } } -void Type::GetHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* seen) const { - if (!seen->insert(this).second) { - return; +size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const { + // Linear search through a dense, cache coherent vector is faster than O(log + // n) search in a complex data structure (eg std::set) for the generally small + // number of nodes. It also skips the overhead of an new/delete per Type + // (when inserting/removing from a set). + if (std::find(seen->begin(), seen->end(), this) != seen->end()) { + return hash; } - words->push_back(kind_); + seen->push_back(this); + + hash = hash_combine(hash, uint32_t(kind_)); for (const auto& d : decorations_) { - for (auto w : d) { - words->push_back(w); - } + hash = hash_combine(hash, d); } switch (kind_) { -#define DeclareKindCase(type) \ - case k##type: \ - As##type()->GetExtraHashWords(words, seen); \ +#define DeclareKindCase(type) \ + case k##type: \ + hash = As##type()->ComputeExtraStateHash(hash, seen); \ break DeclareKindCase(Void); DeclareKindCase(Bool); @@ -232,18 +237,13 @@ void Type::GetHashWords(std::vector<uint32_t>* words, break; } - seen->erase(this); + seen->pop_back(); + return hash; } size_t Type::HashValue() const { - std::u32string h; - std::vector<uint32_t> words; - GetHashWords(&words); - for (auto w : words) { - h.push_back(w); - } - - return std::hash<std::u32string>()(h); + SeenTypes seen; + return ComputeHashValue(0, &seen); } bool Integer::IsSameImpl(const Type* that, IsSameCache*) const { @@ -258,10 +258,8 @@ std::string Integer::str() const { return oss.str(); } -void Integer::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>*) const { - words->push_back(width_); - words->push_back(signed_); +size_t Integer::ComputeExtraStateHash(size_t hash, SeenTypes*) const { + return hash_combine(hash, width_, signed_); } bool Float::IsSameImpl(const Type* that, IsSameCache*) const { @@ -275,9 +273,8 @@ std::string Float::str() const { return oss.str(); } -void Float::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>*) const { - words->push_back(width_); +size_t Float::ComputeExtraStateHash(size_t hash, SeenTypes*) const { + return hash_combine(hash, width_); } Vector::Vector(const Type* type, uint32_t count) @@ -299,10 +296,11 @@ std::string Vector::str() const { return oss.str(); } -void Vector::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* seen) const { - element_type_->GetHashWords(words, seen); - words->push_back(count_); +size_t Vector::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const { + // prefer form that doesn't require push/pop from stack: add state and + // make tail call. + hash = hash_combine(hash, count_); + return element_type_->ComputeHashValue(hash, seen); } Matrix::Matrix(const Type* type, uint32_t count) @@ -324,10 +322,9 @@ std::string Matrix::str() const { return oss.str(); } -void Matrix::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* seen) const { - element_type_->GetHashWords(words, seen); - words->push_back(count_); +size_t Matrix::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const { + hash = hash_combine(hash, count_); + return element_type_->ComputeHashValue(hash, seen); } Image::Image(Type* type, SpvDim dimen, uint32_t d, bool array, bool multisample, @@ -362,16 +359,10 @@ std::string Image::str() const { return oss.str(); } -void Image::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* seen) const { - sampled_type_->GetHashWords(words, seen); - words->push_back(dim_); - words->push_back(depth_); - words->push_back(arrayed_); - words->push_back(ms_); - words->push_back(sampled_); - words->push_back(format_); - words->push_back(access_qualifier_); +size_t Image::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const { + hash = hash_combine(hash, uint32_t(dim_), depth_, arrayed_, ms_, sampled_, + uint32_t(format_), uint32_t(access_qualifier_)); + return sampled_type_->ComputeHashValue(hash, seen); } bool SampledImage::IsSameImpl(const Type* that, IsSameCache* seen) const { @@ -387,9 +378,8 @@ std::string SampledImage::str() const { return oss.str(); } -void SampledImage::GetExtraHashWords( - std::vector<uint32_t>* words, std::unordered_set<const Type*>* seen) const { - image_type_->GetHashWords(words, seen); +size_t SampledImage::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const { + return image_type_->ComputeHashValue(hash, seen); } Array::Array(const Type* type, const Array::LengthInfo& length_info_arg) @@ -422,16 +412,19 @@ std::string Array::str() const { return oss.str(); } -void Array::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* seen) const { - element_type_->GetHashWords(words, seen); - // This should mirror the logic in IsSameImpl - words->insert(words->end(), length_info_.words.begin(), - length_info_.words.end()); +size_t Array::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const { + hash = hash_combine(hash, length_info_.words); + return element_type_->ComputeHashValue(hash, seen); } void Array::ReplaceElementType(const Type* type) { element_type_ = type; } +Array::LengthInfo Array::GetConstantLengthInfo(uint32_t const_id, + uint32_t length) const { + std::vector<uint32_t> extra_words{LengthInfo::Case::kConstant, length}; + return {const_id, extra_words}; +} + RuntimeArray::RuntimeArray(const Type* type) : Type(kRuntimeArray), element_type_(type) { assert(!type->AsVoid()); @@ -450,9 +443,8 @@ std::string RuntimeArray::str() const { return oss.str(); } -void RuntimeArray::GetExtraHashWords( - std::vector<uint32_t>* words, std::unordered_set<const Type*>* seen) const { - element_type_->GetHashWords(words, seen); +size_t RuntimeArray::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const { + return element_type_->ComputeHashValue(hash, seen); } void RuntimeArray::ReplaceElementType(const Type* type) { @@ -509,19 +501,14 @@ std::string Struct::str() const { return oss.str(); } -void Struct::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* seen) const { +size_t Struct::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const { for (auto* t : element_types_) { - t->GetHashWords(words, seen); + hash = t->ComputeHashValue(hash, seen); } for (const auto& pair : element_decorations_) { - words->push_back(pair.first); - for (const auto& d : pair.second) { - for (auto w : d) { - words->push_back(w); - } - } + hash = hash_combine(hash, pair.first, pair.second); } + return hash; } bool Opaque::IsSameImpl(const Type* that, IsSameCache*) const { @@ -536,11 +523,8 @@ std::string Opaque::str() const { return oss.str(); } -void Opaque::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>*) const { - for (auto c : name_) { - words->push_back(static_cast<char32_t>(c)); - } +size_t Opaque::ComputeExtraStateHash(size_t hash, SeenTypes*) const { + return hash_combine(hash, name_); } Pointer::Pointer(const Type* type, SpvStorageClass sc) @@ -569,10 +553,9 @@ std::string Pointer::str() const { return os.str(); } -void Pointer::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* seen) const { - pointee_type_->GetHashWords(words, seen); - words->push_back(storage_class_); +size_t Pointer::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const { + hash = hash_combine(hash, uint32_t(storage_class_)); + return pointee_type_->ComputeHashValue(hash, seen); } void Pointer::SetPointeeType(const Type* type) { pointee_type_ = type; } @@ -606,12 +589,11 @@ std::string Function::str() const { return oss.str(); } -void Function::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* seen) const { - return_type_->GetHashWords(words, seen); +size_t Function::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const { for (const auto* t : param_types_) { - t->GetHashWords(words, seen); + hash = t->ComputeHashValue(hash, seen); } + return return_type_->ComputeHashValue(hash, seen); } void Function::SetReturnType(const Type* type) { return_type_ = type; } @@ -628,9 +610,8 @@ std::string Pipe::str() const { return oss.str(); } -void Pipe::GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>*) const { - words->push_back(access_qualifier_); +size_t Pipe::ComputeExtraStateHash(size_t hash, SeenTypes*) const { + return hash_combine(hash, uint32_t(access_qualifier_)); } bool ForwardPointer::IsSameImpl(const Type* that, IsSameCache*) const { @@ -653,11 +634,11 @@ std::string ForwardPointer::str() const { return oss.str(); } -void ForwardPointer::GetExtraHashWords( - std::vector<uint32_t>* words, std::unordered_set<const Type*>* seen) const { - words->push_back(target_id_); - words->push_back(storage_class_); - if (pointer_) pointer_->GetHashWords(words, seen); +size_t ForwardPointer::ComputeExtraStateHash(size_t hash, + SeenTypes* seen) const { + hash = hash_combine(hash, target_id_, uint32_t(storage_class_)); + if (pointer_) hash = pointer_->ComputeHashValue(hash, seen); + return hash; } CooperativeMatrixNV::CooperativeMatrixNV(const Type* type, const uint32_t scope, @@ -681,12 +662,10 @@ std::string CooperativeMatrixNV::str() const { return oss.str(); } -void CooperativeMatrixNV::GetExtraHashWords( - std::vector<uint32_t>* words, std::unordered_set<const Type*>* pSet) const { - component_type_->GetHashWords(words, pSet); - words->push_back(scope_id_); - words->push_back(rows_id_); - words->push_back(columns_id_); +size_t CooperativeMatrixNV::ComputeExtraStateHash(size_t hash, + SeenTypes* seen) const { + hash = hash_combine(hash, scope_id_, rows_id_, columns_id_); + return component_type_->ComputeHashValue(hash, seen); } bool CooperativeMatrixNV::IsSameImpl(const Type* that, diff --git a/source/opt/types.h b/source/opt/types.h index 9ecd41a6..f5a4a6be 100644 --- a/source/opt/types.h +++ b/source/opt/types.h @@ -28,6 +28,7 @@ #include "source/latest_version_spirv_header.h" #include "source/opt/instruction.h" +#include "source/util/small_vector.h" #include "spirv-tools/libspirv.h" namespace spvtools { @@ -67,6 +68,8 @@ class Type { public: typedef std::set<std::pair<const Pointer*, const Pointer*>> IsSameCache; + using SeenTypes = spvtools::utils::SmallVector<const Type*, 8>; + // Available subtypes. // // When adding a new derived class of Type, please add an entry to the enum. @@ -96,7 +99,8 @@ class Type { kNamedBarrier, kAccelerationStructureNV, kCooperativeMatrixNV, - kRayQueryKHR + kRayQueryKHR, + kLast }; Type(Kind k) : kind_(k) {} @@ -154,21 +158,7 @@ class Type { // Returns the hash value of this type. size_t HashValue() const; - // Adds the necessary words to compute a hash value of this type to |words|. - void GetHashWords(std::vector<uint32_t>* words) const { - std::unordered_set<const Type*> seen; - GetHashWords(words, &seen); - } - - // Adds the necessary words to compute a hash value of this type to |words|. - void GetHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* seen) const; - - // Adds necessary extra words for a subtype to calculate a hash value into - // |words|. - virtual void GetExtraHashWords( - std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const = 0; + size_t ComputeHashValue(size_t hash, SeenTypes* seen) const; // A bunch of methods for casting this type to a given type. Returns this if the // cast can be done, nullptr otherwise. @@ -204,6 +194,10 @@ class Type { DeclareCastMethod(RayQueryKHR) #undef DeclareCastMethod +protected: + // Add any type-specific state to |hash| and returns new hash. + virtual size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const = 0; + protected: // Decorations attached to this type. Each decoration is encoded as a vector // of uint32_t numbers. The first uint32_t number is the decoration value, @@ -232,8 +226,7 @@ class Integer : public Type { uint32_t width() const { return width_; } bool IsSigned() const { return signed_; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -253,8 +246,7 @@ class Float : public Type { const Float* AsFloat() const override { return this; } uint32_t width() const { return width_; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -274,8 +266,7 @@ class Vector : public Type { Vector* AsVector() override { return this; } const Vector* AsVector() const override { return this; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -296,8 +287,7 @@ class Matrix : public Type { Matrix* AsMatrix() override { return this; } const Matrix* AsMatrix() const override { return this; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -327,8 +317,7 @@ class Image : public Type { SpvImageFormat format() const { return format_; } SpvAccessQualifier access_qualifier() const { return access_qualifier_; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -355,8 +344,7 @@ class SampledImage : public Type { const Type* image_type() const { return image_type_; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -399,10 +387,10 @@ class Array : public Type { Array* AsArray() override { return this; } const Array* AsArray() const override { return this; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; void ReplaceElementType(const Type* element_type); + LengthInfo GetConstantLengthInfo(uint32_t const_id, uint32_t length) const; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -422,8 +410,7 @@ class RuntimeArray : public Type { RuntimeArray* AsRuntimeArray() override { return this; } const RuntimeArray* AsRuntimeArray() const override { return this; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; void ReplaceElementType(const Type* element_type); @@ -459,8 +446,7 @@ class Struct : public Type { Struct* AsStruct() override { return this; } const Struct* AsStruct() const override { return this; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -491,8 +477,7 @@ class Opaque : public Type { const std::string& name() const { return name_; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -512,8 +497,7 @@ class Pointer : public Type { Pointer* AsPointer() override { return this; } const Pointer* AsPointer() const override { return this; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; void SetPointeeType(const Type* type); @@ -539,8 +523,7 @@ class Function : public Type { const std::vector<const Type*>& param_types() const { return param_types_; } std::vector<const Type*>& param_types() { return param_types_; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>*) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; void SetReturnType(const Type* type); @@ -564,8 +547,7 @@ class Pipe : public Type { SpvAccessQualifier access_qualifier() const { return access_qualifier_; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -592,8 +574,7 @@ class ForwardPointer : public Type { ForwardPointer* AsForwardPointer() override { return this; } const ForwardPointer* AsForwardPointer() const override { return this; } - void GetExtraHashWords(std::vector<uint32_t>* words, - std::unordered_set<const Type*>* pSet) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; private: bool IsSameImpl(const Type* that, IsSameCache*) const override; @@ -616,8 +597,7 @@ class CooperativeMatrixNV : public Type { return this; } - void GetExtraHashWords(std::vector<uint32_t>*, - std::unordered_set<const Type*>*) const override; + size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override; const Type* component_type() const { return component_type_; } uint32_t scope_id() const { return scope_id_; } @@ -633,24 +613,25 @@ class CooperativeMatrixNV : public Type { const uint32_t columns_id_; }; -#define DefineParameterlessType(type, name) \ - class type : public Type { \ - public: \ - type() : Type(k##type) {} \ - type(const type&) = default; \ - \ - std::string str() const override { return #name; } \ - \ - type* As##type() override { return this; } \ - const type* As##type() const override { return this; } \ - \ - void GetExtraHashWords(std::vector<uint32_t>*, \ - std::unordered_set<const Type*>*) const override {} \ - \ - private: \ - bool IsSameImpl(const Type* that, IsSameCache*) const override { \ - return that->As##type() && HasSameDecorations(that); \ - } \ +#define DefineParameterlessType(type, name) \ + class type : public Type { \ + public: \ + type() : Type(k##type) {} \ + type(const type&) = default; \ + \ + std::string str() const override { return #name; } \ + \ + type* As##type() override { return this; } \ + const type* As##type() const override { return this; } \ + \ + size_t ComputeExtraStateHash(size_t hash, SeenTypes*) const override { \ + return hash; \ + } \ + \ + private: \ + bool IsSameImpl(const Type* that, IsSameCache*) const override { \ + return that->As##type() && HasSameDecorations(that); \ + } \ } DefineParameterlessType(Void, void); DefineParameterlessType(Bool, bool); diff --git a/source/opt/unify_const_pass.cpp b/source/opt/unify_const_pass.cpp index 227fd61d..6bfa11a5 100644 --- a/source/opt/unify_const_pass.cpp +++ b/source/opt/unify_const_pass.cpp @@ -151,7 +151,7 @@ Pass::Status UnifyConstantPass::Process() { // 'SpecId' decoration and all of them should be treated as unique. // 'SpecId' is not applicable to SpecConstants defined with // OpSpecConstant{Op|Composite}, their values are not necessary to be - // unique. When all the operands/compoents are the same between two + // unique. When all the operands/components are the same between two // OpSpecConstant{Op|Composite} results, their result values must be the // same so are unifiable. case SpvOp::SpvOpSpecConstantOp: diff --git a/source/opt/upgrade_memory_model.cpp b/source/opt/upgrade_memory_model.cpp index ab252059..9d6a5bce 100644 --- a/source/opt/upgrade_memory_model.cpp +++ b/source/opt/upgrade_memory_model.cpp @@ -20,6 +20,7 @@ #include "source/opt/ir_context.h" #include "source/spirv_constant.h" #include "source/util/make_unique.h" +#include "source/util/string_utils.h" namespace spvtools { namespace opt { @@ -58,9 +59,7 @@ void UpgradeMemoryModel::UpgradeMemoryModelInstruction() { std::initializer_list<Operand>{ {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityVulkanMemoryModelKHR}}})); const std::string extension = "SPV_KHR_vulkan_memory_model"; - std::vector<uint32_t> words(extension.size() / 4 + 1, 0); - char* dst = reinterpret_cast<char*>(words.data()); - strncpy(dst, extension.c_str(), extension.size()); + std::vector<uint32_t> words = spvtools::utils::MakeVector(extension); context()->AddExtension( MakeUnique<Instruction>(context(), SpvOpExtension, 0, 0, std::initializer_list<Operand>{ @@ -85,8 +84,7 @@ void UpgradeMemoryModel::UpgradeInstructions() { if (ext_inst == GLSLstd450Modf || ext_inst == GLSLstd450Frexp) { auto import = get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u)); - if (reinterpret_cast<char*>(import->GetInOperand(0u).words.data()) == - std::string("GLSL.std.450")) { + if (import->GetInOperand(0u).AsString() == "GLSL.std.450") { UpgradeExtInst(inst); } } diff --git a/source/opt/vector_dce.h b/source/opt/vector_dce.h index 4d30b926..a55bda69 100644 --- a/source/opt/vector_dce.h +++ b/source/opt/vector_dce.h @@ -73,7 +73,7 @@ class VectorDCE : public MemPass { bool RewriteInstructions(Function* function, const LiveComponentMap& live_components); - // Makrs all DebugValue instructions that use |composite| for their values as + // Makes all DebugValue instructions that use |composite| for their values as // dead instructions by putting them into |dead_dbg_value|. void MarkDebugValueUsesAsDead(Instruction* composite, std::vector<Instruction*>* dead_dbg_value); |